[mlir] [clang-tools-extra] [llvm] Implement grouped conv interface (PR #80870)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 6 08:40:11 PST 2024
https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/80870
>From ed244c7cbd95294077fa603e022ac234aaf19aa2 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 30 Dec 2023 13:51:52 -0600
Subject: [PATCH 1/5] Implement GroupedConvolutionOpInterface
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 36 +++++
.../Dialect/Linalg/IR/LinalgInterfaces.td | 51 +++++++
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 127 ++++++++++++++++++
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 102 ++++++++++++++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 105 +++++++++++++++
mlir/test/Dialect/Linalg/loops.mlir | 8 +-
mlir/test/Dialect/Linalg/named-ops.mlir | 11 ++
mlir/test/Dialect/Linalg/tile-conv.mlir | 2 +-
8 files changed, 436 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 6c8240267e7d05..e2f24432c003b4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -28,6 +28,8 @@ namespace mlir {
namespace linalg {
class IteratorTypeAttr;
class LinalgOp;
+class ConvolutionOpInterface;
+class GroupedConvolutionOpInterface;
namespace detail {
/// Implementation of the method that check if given operands
@@ -115,6 +117,37 @@ bool isaCopyOpInterface(LinalgOp linalgOp);
namespace detail {
+// Common implementations for ConvolutionOpInterface
+namespace convolution_impl {
+// Returns strides as a vector.
+SmallVector<int64_t, 2> getStrides(ConvolutionOpInterface op);
+// Returns dilations as a vector.
+SmallVector<int64_t, 2> getDilations(ConvolutionOpInterface op);
+// Region builder for basic convolution
+void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs);
+// Region builder for basic quantized convolution
+void quantizedRegionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs);
+void getEffects(
+ Operation *op,
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects);
+ParseResult parse(OpAsmParser &parser, OperationState &result,
+ bool isQuantized = false);
+void print(LinalgOp op, OpAsmPrinter &p);
+} // namespace convolution_impl
+
+// Common implementations for GroupedConvolutionOpInterface
+namespace grouped_convolution_impl {
+int64_t getSpatialRank(GroupedConvolutionOpInterface op);
+ArrayAttr createCommonIndexingMaps(MLIRContext *ctx, int64_t numSpatial,
+ int64_t channelPos,
+ const SmallVectorImpl<int64_t> &strides,
+ const SmallVectorImpl<int64_t> &dilations);
+ArrayAttr getIteratorTypes(GroupedConvolutionOpInterface op);
+} // namespace grouped_convolution_impl
+
/// Returns true if the block contains a contraction of the following form:
///
/// %0 = <elemwise>(permutation-of(cu(block-argument-0),
@@ -171,6 +204,9 @@ LogicalResult verifyContractionInterface(Operation *op);
/// Verify that `op` conforms to the ConvolutionOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op);
+/// Verify that `op` conforms to the GroupedConvolutionOpInterface.
+LogicalResult verifyGroupedConvolutionInterface(Operation *op);
+
/// Verify that `op` conforms to the FillOpInterface.
LogicalResult verifyFillInterface(Operation *op);
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9b..170ebf9d43030f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -175,6 +175,57 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
return $_op.getOperation()->getOperand(1);
}]
>,
+ InterfaceMethod<
+ /*desc=*/"Return the spatial rank.",
+ /*retTy=*/"int64_t",
+ /*methodName=*/"getSpatialRank",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ // Most convolutions' inputs have batch, channel and spatial dims
+ return cast<ShapedType>(image().getType()).getRank() - 2;
+ }]
+ >
+ ];
+}
+
+def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInterface", [
+ LinalgConvolutionOpInterface]> {
+ let description = [{
+ A grouped convolution is defined in general terms:
+ 1. It is a convolution as defined by `ConvolutionOpInterface`.
+ 2. Operands have the following distinct dimensions (excluding batch in input/output): group, channel, spatial
+ 3. `input_rank == kernel_rank == output_rank` (including batch in input/output)
+ 4. Reductions are along the input channel and spatial dimensions while group, output channel
+ and output spatial dimensions are parallel.
+ }];
+ let cppNamespace = "::mlir::linalg";
+ let verify = [{ return detail::verifyGroupedConvolutionInterface($_op); }];
+ let methods = [
+ InterfaceMethod<[{
+ Returns indexing maps for any spatial dimension.
+ }],
+ "::mlir::ArrayAttr", "getIteratorTypes", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return detail::grouped_convolution_impl::getIteratorTypes($_op);
+ }]>,
+ InterfaceMethod<[{
+ Returns strides.
+ }],
+ "::llvm::SmallVector<int64_t, 2>", "getStridesVector", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return detail::convolution_impl::getStrides($_op);
+ }]>,
+ InterfaceMethod<[{
+ Returns dilations.
+ }],
+ "::llvm::SmallVector<int64_t, 2>", "getDilationsVector", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return detail::convolution_impl::getDilations($_op);
+ }]>
];
}
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 751edd02288301..44db786a64595e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -384,6 +384,133 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// GroupedConvNDOp ops.
+//===----------------------------------------------------------------------===//
+
+def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
+ [AttrSizedOperandSegments, LinalgGroupedConvolutionOpInterface]> {
+
+ let summary = [{
+ Performs N-D grouped convolution with switchable channel position; either first or last.
+ }];
+ let description = [{
+ Allows any number of spatial dimensions but treats all of them as contiguous. Throughout, `S`
+ will represent all spatial dimensions. Operand layouts are determined by the `channel_first`
+ `bool` attribute. When placing the channel dim first or last, the batch dim is excluded. In
+ any case, the channel and spatial dims are in the same relative order for all operands.
+
+ Domain: N, G, F, S, C, KS
+
+ Layouts:
+ `channel_first == true`:
+ Input: `NGCS`
+ Kernel: `GFCS`
+ Output: `NGFS`
+
+ `channel_first == false`:
+ Input: `NSGC`
+ Kernel: `SGFC`
+ Output: `NSGF`
+
+ }];
+
+ let arguments = (ins
+ Variadic<TensorOrMemref>:$inputs,
+ Variadic<TensorOrMemref>:$inits,
+ DefaultValuedAttr<BoolAttr, "true">:$channel_first,
+ OptionalAttr<I64ElementsAttr>:$strides,
+ OptionalAttr<I64ElementsAttr>:$dilations
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "Value":$input, "Value":$filter, "Value":$init, CArg<"bool", "true">:$channel_first,
+ CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
+ int64_t numSpatialDims = input.getType().cast<ShapedType>().getRank() - 3;
+ if (strides.empty())
+ strides = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
+ if (dilations.empty())
+ dilations = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
+ $_state.addAttribute(getStridesAttrName($_state.name),
+ ::mlir::DenseElementsAttr::get(
+ ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), strides));
+ $_state.addAttribute(getDilationsAttrName($_state.name),
+ ::mlir::DenseElementsAttr::get(
+ ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), dilations));
+ buildStructuredOp($_builder, $_state, std::nullopt, {input, filter}, init,
+ attributes, GroupedConvNDOp::getRegionBuilder());
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
+ "Value":$init, CArg<"bool", "true">:$channel_first,
+ CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
+ int64_t numSpatialDims = input.getType().cast<ShapedType>().getRank() - 3;
+ if (strides.empty())
+ strides = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
+ if (dilations.empty())
+ dilations = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
+ $_state.addAttribute(getStridesAttrName($_state.name),
+ ::mlir::DenseElementsAttr::get(
+ ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), strides));
+ $_state.addAttribute(getDilationsAttrName($_state.name),
+ ::mlir::DenseElementsAttr::get(
+ ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), dilations));
+ buildStructuredOp($_builder, $_state, resultTensorTypes,
+ {input, filter}, init, attributes, GroupedConvNDOp::getRegionBuilder());
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
+ "Value":$init, "Attribute":$channel_first, "Attribute":$strides, "Attribute":$dilations,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute(getChannelFirstAttrName($_state.name), channel_first);
+ $_state.addAttribute(getStridesAttrName($_state.name), strides);
+ $_state.addAttribute(getDilationsAttrName($_state.name), dilations);
+ buildStructuredOp($_builder, $_state, resultTensorTypes, {input, filter}, init,
+ attributes, GroupedConvNDOp::getRegionBuilder());
+ }]>
+ ];
+
+ // TODO: Figure out how to move this to the interface
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ void print(::mlir::OpAsmPrinter &printer) {
+ return detail::convolution_impl::print(*this, printer);
+ }
+ static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
+ ::mlir::OperationState &result) {
+ return detail::convolution_impl::parse(parser, result);
+ }
+ static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+ mlir::ArrayRef<mlir::NamedAttribute>)>
+ getRegionBuilder() {
+ return detail::convolution_impl::regionBuilder;
+ }
+ // Implement functions necessary for DestinationStyleOpInterface.
+ MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
+
+ // Implement functions necessary for LinalgOp.
+ ArrayAttr getIndexingMaps();
+
+ // Implement functions necessary for GroupedConvolutionOpInterface
+ int64_t getSpatialRank() {
+ return detail::grouped_convolution_impl::getSpatialRank(*this);
+ }
+
+ int64_t getChannelPosition() {
+ return (getChannelFirstAttr().getValue()) ? 1 : getSpatialRank() + 1;
+ }
+ }];
+}
//===----------------------------------------------------------------------===//
// Transpose op.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ba419d32f22a3e..ba28c4ed954970 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -638,6 +638,108 @@ enum class MatchConvolutionResult {
};
} // namespace mlir::linalg::detail
+SmallVector<int64_t, 2>
+mlir::linalg::detail::convolution_impl::getStrides(ConvolutionOpInterface op) {
+ auto maybeStridesAttr = op->getAttrOfType<DenseIntElementsAttr>("strides");
+ if (!maybeStridesAttr) {
+ OpBuilder builder(op.getContext());
+ return SmallVector<int64_t, 2>(op.getSpatialRank(), 1);
+ }
+ return llvm::to_vector(maybeStridesAttr.getValues<int64_t>());
+}
+
+SmallVector<int64_t, 2> mlir::linalg::detail::convolution_impl::getDilations(
+ ConvolutionOpInterface op) {
+ auto maybeDilationsAttr =
+ op->getAttrOfType<DenseIntElementsAttr>("dilations");
+ if (!maybeDilationsAttr) {
+ OpBuilder builder(op.getContext());
+ return SmallVector<int64_t, 2>(op.getSpatialRank(), 1);
+ }
+ return llvm::to_vector(maybeDilationsAttr.getValues<int64_t>());
+}
+
+int64_t mlir::linalg::detail::grouped_convolution_impl::getSpatialRank(
+ GroupedConvolutionOpInterface op) {
+ return cast<ShapedType>(op.image().getType()).getRank() - 3;
+}
+
+ArrayAttr mlir::linalg::detail::grouped_convolution_impl::getIteratorTypes(
+ GroupedConvolutionOpInterface op) {
+ int64_t numSpatialDims = op.getSpatialRank();
+ SmallVector<Attribute> iteratorTypes(
+ 3 + numSpatialDims, IteratorTypeAttr::get(op.getContext(), par));
+ SmallVector<Attribute> reductions(
+ numSpatialDims + 1, IteratorTypeAttr::get(op.getContext(), red));
+ iteratorTypes.insert(iteratorTypes.end(), reductions.begin(),
+ reductions.end());
+
+ return Builder(op.getContext()).getArrayAttr(iteratorTypes);
+}
+
+ArrayAttr
+mlir::linalg::detail::grouped_convolution_impl::createCommonIndexingMaps(
+ MLIRContext *ctx, int64_t numSpatial, int64_t channelPos,
+ const SmallVectorImpl<int64_t> &strides,
+ const SmallVectorImpl<int64_t> &dilations) {
+
+ // Domain: (n, g, f, os, c, ks)
+ AffineExpr n = getAffineDimExpr(0, ctx);
+ AffineExpr g = getAffineDimExpr(1, ctx);
+ AffineExpr f = getAffineDimExpr(2, ctx);
+ SmallVector<AffineExpr> s(
+ llvm::map_range(llvm::seq<int64_t>(3, numSpatial + 3),
+ [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+ AffineExpr c = getAffineDimExpr(numSpatial + 3, ctx);
+ SmallVector<AffineExpr> ks(llvm::map_range(
+ llvm::seq<int64_t>(numSpatial + 4, 2 * (numSpatial + 1) + 2),
+ [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+
+ // Initialize operand accesses in nw order and insert c according to channel
+ // position
+ SmallVector<AffineExpr> inExprs = {n}, outExprs = {n};
+ SmallVector<AffineExpr> gc = {g, c};
+ SmallVector<AffineExpr> gf = {g, f};
+ SmallVector<AffineExpr> gfc = {g, f, c};
+ for (const auto &[sp, ksp, st, di] : llvm::zip(s, ks, strides, dilations)) {
+ inExprs.push_back(sp * st + ksp * di);
+ outExprs.push_back(sp);
+ }
+ SmallVector<AffineExpr> kExprs(ks);
+ inExprs.insert(inExprs.begin() + channelPos, gc.begin(), gc.end());
+ kExprs.insert(channelPos == 0 ? kExprs.begin()
+ : kExprs.begin() + channelPos - 1,
+ gfc.begin(), gfc.end());
+ outExprs.insert(outExprs.begin() + channelPos, gf.begin(), gf.end());
+ SmallVector<AffineMap> maps(
+ {AffineMap::get(4 + 2 * numSpatial, 0, inExprs, ctx),
+ AffineMap::get(4 + 2 * numSpatial, 0, kExprs, ctx),
+ AffineMap::get(4 + 2 * numSpatial, 0, outExprs, ctx)});
+
+ return Builder(ctx).getAffineMapArrayAttr(maps);
+}
+
+LogicalResult
+mlir::linalg::detail::verifyGroupedConvolutionInterface(Operation *op) {
+ if (failed(verifyConvolutionInterface(op)))
+ return failure();
+ if (GroupedConvolutionOpInterface conv =
+ dyn_cast<GroupedConvolutionOpInterface>(op)) {
+ const auto imageType = conv.image().getType().dyn_cast<ShapedType>();
+ const auto imageRank = imageType.getRank();
+ const auto kernelRank =
+ conv.filter().getType().cast<ShapedType>().getRank();
+ const auto initType =
+ cast<LinalgOp>(op).getDpsInits()[0].getType().dyn_cast<ShapedType>();
+ const auto initRank = initType.getRank();
+ if (imageRank != kernelRank || imageRank != initRank)
+ return op->emitError(
+ "Rank relationship must be `in_rank == out_rank == kernel_rank`");
+ return success();
+ }
+ return failure();
+}
+
mlir::linalg::detail::MatchConvolutionResult
mlir::linalg::detail::isConvolutionInterfaceImpl(
Operation *op, ConvolutionDimensions *dimensions) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b68aa77fd83a1c..c31b17082b8900 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1663,6 +1663,111 @@ LogicalResult ReduceOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// ConvolutionOpInterface
+//===----------------------------------------------------------------------===//
+
+// There must be a way to avoid defining the following 3 functions
+ParseResult mlir::linalg::detail::convolution_impl::parse(
+ OpAsmParser &parser, OperationState &result, bool isQuantized) {
+ if (isQuantized)
+ return parseNamedStructuredOp(
+ parser, result, 5,
+ mlir::linalg::detail::convolution_impl::quantizedRegionBuilder);
+ return parseNamedStructuredOp(
+ parser, result, 3, mlir::linalg::detail::convolution_impl::regionBuilder);
+}
+
+void mlir::linalg::detail::convolution_impl::print(LinalgOp op,
+ OpAsmPrinter &p) {
+ printNamedStructuredOp(p, op.getOperation(), op.getDpsInputs(),
+ op.getDpsInits());
+}
+
+// Build {mul, add} region for convolution
+void mlir::linalg::detail::convolution_impl::regionBuilder(
+ ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 3 &&
+ "ConvolutionInterface regionBuilder expects 3 (>=0) args");
+ RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
+ SmallVector<Value> yields;
+
+ Value value1 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(0));
+ Value value2 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(1));
+ Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+ Value value4 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+ yields.push_back(value4);
+ helper.yieldOutputs(yields);
+}
+
+void mlir::linalg::detail::convolution_impl::quantizedRegionBuilder(
+ ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 5 &&
+ "ConvolutionInterface regionBuilder expects 5 args");
+ RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
+ Value value1 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+ block.getArgument(0));
+ Value value2 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+ block.getArgument(2));
+ Value value3 = helper.buildBinaryFn(BinaryFn::sub, value1, value2);
+ Value value4 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+ block.getArgument(1));
+ Value value5 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+ block.getArgument(3));
+ Value value6 = helper.buildBinaryFn(BinaryFn::sub, value4, value5);
+ Value value7 = helper.buildBinaryFn(BinaryFn::mul, value3, value6);
+ Value value8 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(4), value7);
+ helper.yieldOutputs({value8});
+}
+
+void mlir::linalg::detail::convolution_impl::getEffects(
+ Operation *op,
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (!isa<ConvolutionOpInterface>(op))
+ return;
+ if (LinalgOp linalgOp = dyn_cast<LinalgOp>(op)) {
+ if (linalgOp.hasTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, linalgOp.getOperation()->getResults(),
+ linalgOp.getDpsInputs(), linalgOp.getDpsInits());
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// GroupedConvNDOp
+//===----------------------------------------------------------------------===//
+
+void GroupedConvNDOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ return detail::convolution_impl::getEffects(*this, effects);
+}
+
+ArrayAttr GroupedConvNDOp::getIndexingMaps() {
+ ArrayAttr cached = (*this)->getAttrOfType<ArrayAttr>(
+ LinalgDialect::kMemoizedIndexingMapsAttrName);
+ if (cached)
+ return cached;
+
+ cached = detail::grouped_convolution_impl::createCommonIndexingMaps(
+ getContext(), getSpatialRank(), getChannelPosition(), getStridesVector(),
+ getDilationsVector());
+
+ (*this)->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+ return cached;
+}
+
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 8c13422fd63833..640680483130d7 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1,12 +1,10 @@
-// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s
-// RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefix=CHECKPARALLEL %s
+// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s --check-prefixes=COMMON,CHECK
+// RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefixes=COMMON,CHECKPARALLEL %s
// Test that we can lower all the way to LLVM without crashing, don't check results here.
// RUN: mlir-opt %s -convert-linalg-to-loops -test-lower-to-llvm -o=/dev/null 2>&1
-// CHECK: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
-
-// CHECKPARALLEL: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// COMMON: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
func.func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 29977a71dbb864..bd728edd1ec715 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1,5 +1,16 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+// -----
+
+// CHECK-LABEL: func @gen_grouped_1D_channel_first_memref
+func.func @gen_grouped_1D_channel_first_memref(%arg0: memref<64x8x16x10xf32>, %arg1: memref<8x32x16x3xf32>, %arg2: memref<64x8x32x8xf32>) {
+ // CHECK: grouped_conv_nd {{.*}}channel_first = true
+ linalg.grouped_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x8x16x10xf32>, memref<8x32x16x3xf32>) outs(%arg2: memref<64x8x32x8xf32>)
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @depthwise_conv_1d_nwc_wcm
func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<1x12x8xf32>, %filter: tensor<3x8x8xf32>) -> tensor<1x10x8x8xf32> {
%zero = arith.constant 0.000000e+00 : f32
diff --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir
index 4a940f12662e6c..1662f5c45fe804 100644
--- a/mlir/test/Dialect/Linalg/tile-conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile-conv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -transform-interpreter -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter -canonicalize -split-input-file | FileCheck %s
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
>From dbd721546e2f2811f4ed6a118412c4cd198530d6 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 31 Dec 2023 09:58:37 -0600
Subject: [PATCH 2/5] Add bufferization test
---
mlir/test/Dialect/Linalg/bufferize.mlir | 22 ++++++++++++++++++++++
1 file changed, 22 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 29f27e6838e661..9d3444fe2ce9cb 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -217,3 +217,25 @@ func.func public @main(%arg0: tensor<2x3xi1>) -> tensor<6xi64> {
%3 = func.call @csum(%2) : (tensor<6xi64>) -> tensor<6xi64>
return %3 : tensor<6xi64>
}
+
+
+
+// -----
+
+// CHECK-LABEL: func @gen_grouped_3D_channel_first_tensor(
+// CHECK-SAME: %[[ARG0_TENSOR:.*]]: tensor<64x2x16x26x26x26xf32>,
+// CHECK-SAME: %[[ARG1_TENSOR:.*]]: tensor<2x20x16x3x3x3xf32>,
+// CHECK-SAME: %[[ARG2_TENSOR:.*]]: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32> {
+// CHECK-DAG: %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0_TENSOR]] : memref<64x2x16x26x26x26xf32>
+// CHECK-DAG: %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1_TENSOR]] : memref<2x20x16x3x3x3xf32>
+// CHECK-DAG: %[[ARG2_MEMREF:.*]] = bufferization.to_memref %[[ARG2_TENSOR]] : memref<64x2x20x8x8x8xf32>
+// CHECK-DAG: %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<64x2x20x8x8x8xf32>
+// CHECK: memref.copy %[[ARG2_MEMREF]], %[[INIT_BUFFER]] : memref<64x2x20x8x8x8xf32> to memref<64x2x20x8x8x8xf32>
+// CHECK: linalg.grouped_conv_nd
+// CHECK-SAME: {channel_first = true, dilations = dense<2> : tensor<3xi64>, strides = dense<3> : tensor<3xi64>}
+// CHECK-SAME: ins(%[[ARG0_MEMREF]], %[[ARG1_MEMREF]] : memref<64x2x16x26x26x26xf32>, memref<2x20x16x3x3x3xf32>)
+// CHECK-SAME: outs(%[[INIT_BUFFER]] : memref<64x2x20x8x8x8xf32>)
+func.func @gen_grouped_3D_channel_first_tensor(%arg0: tensor<64x2x16x26x26x26xf32>, %arg1: tensor<2x20x16x3x3x3xf32>, %arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32> {
+ %0 = linalg.grouped_conv_nd {channel_first = true, strides = dense<3> : tensor<3xi64>, dilations = dense<2> : tensor<3xi64>} ins(%arg0, %arg1: tensor<64x2x16x26x26x26xf32>, tensor<2x20x16x3x3x3xf32>) outs(%arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32>
+ return %0 : tensor<64x2x20x8x8x8xf32>
+}
>From 162904694b84924261c6ddd1f9f4430df66f6a10 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 31 Dec 2023 10:46:45 -0600
Subject: [PATCH 3/5] Add tiling regression test
---
mlir/test/Dialect/Linalg/tile-conv.mlir | 60 +++++++++++++++++++++++++
1 file changed, 60 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir
index 1662f5c45fe804..50065ccb18cb84 100644
--- a/mlir/test/Dialect/Linalg/tile-conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile-conv.mlir
@@ -41,3 +41,63 @@ module attributes {transform.with_named_sequence} {
// CHECK: linalg.conv_2d
// CHECK-SAME: ins(%[[SVIN]], %[[SVKER]]
// CHECK-SAME: outs(%[[SVOUT]]
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 6)>
+// CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0)[s0] -> (d0 + s0 - 1)>
+
+func.func @grouped_conv_2D(%arg0 : memref<?x?x?x?x?xf32>, %arg1 : memref<?x?x?x?x?xf32>, %arg2 : memref<?x?x?x?x?xf32>) {
+ linalg.grouped_conv_nd ins(%arg0, %arg1 : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>) outs(%arg2 : memref<?x?x?x?x?xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.grouped_conv_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loop:5 = transform.structured.tile_using_for %0 [2, 3, 4, 5, 6] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: func @grouped_conv_2D
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?x?x?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?x?x?x?xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[BATCH:.*]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[GROUPS:.*]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[IN_CHANNELS:.*]] = memref.dim %[[ARG0]], %[[C2]]
+// CHECK-DAG: %[[OUT_CHANNELS:.*]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[KW:.*]] = memref.dim %[[ARG1]], %[[C3]]
+// CHECK-DAG: %[[KH:.*]] = memref.dim %[[ARG1]], %[[C4]]
+// CHECK-DAG: %[[W:.*]] = memref.dim %[[ARG2]], %[[C3]]
+// CHECK-DAG: %[[H:.*]] = memref.dim %[[ARG2]], %[[C4]]
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[BATCH]] step %[[C2]]
+// CHECK: %[[T4:.*]] = affine.min #[[MAP0]](%[[I]])[%[[BATCH]]]
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[GROUPS]] step %[[C3]]
+// CHECK: %[[T5:.*]] = affine.min #[[MAP1]](%[[J]])[%[[GROUPS]]]
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[OUT_CHANNELS]] step %[[C4]]
+// CHECK-DAG: %[[T6:.*]] = affine.min #[[MAP2]](%[[K]])[%[[OUT_CHANNELS]]]
+// CHECK: scf.for %[[L:.*]] = %[[C0]] to %[[W]] step %[[C5]]
+// CHECK-DAG: %[[T7:.*]] = affine.min #[[MAP3]](%[[L]])[%[[W]]]
+// CHECK: scf.for %[[M:.*]] = %[[C0]] to %[[H]] step %[[C6]]
+// CHECK-DAG: %[[T8:.*]] = affine.min #[[MAP4]](%[[M]])[%[[H]]]
+// CHECK-DAG: %[[T9:.*]] = affine.apply #[[MAP5]](%[[T7]])[%[[KW]]]
+// CHECK-DAG: %[[T10:.*]] = affine.apply #[[MAP5]](%[[T8]])[%[[KH]]]
+// CHECK-DAG: %[[SVIN:.*]] = memref.subview %[[ARG0]][%[[I]], %[[J]], 0, %[[L]], %[[M]]] [%[[T4]], %[[T5]], %[[IN_CHANNELS]], %[[T9]], %[[T10]]]
+// CHECK-DAG: %[[SVKER:.*]] = memref.subview %[[ARG1]][%[[J]], %[[K]], 0, 0, 0] [%[[T5]], %[[T6]], %[[IN_CHANNELS]], %[[KW]], %[[KH]]]
+// CHECK-DAG: %[[SVOUT:.*]] = memref.subview %[[ARG2]][%[[I]], %[[J]], %[[K]], %[[L]], %[[M]]] [%[[T4]], %[[T5]], %[[T6]], %[[T7]], %[[T8]]]
+// CHECK: linalg.grouped_conv_nd {channel_first = true}
+// CHECK-SAME: ins(%[[SVIN]], %[[SVKER]]
+// CHECK-SAME: outs(%[[SVOUT]]
\ No newline at end of file
>From 1969a22b7b5b50dfeaf27a56237fdbaf99947ae2 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 31 Dec 2023 11:32:46 -0600
Subject: [PATCH 4/5] Add interface methods for getting channel and group sizes
---
.../Dialect/Linalg/IR/LinalgInterfaces.td | 29 +++++++++++++++++++
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 2 +-
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 6 ++--
3 files changed, 33 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 170ebf9d43030f..3feadfa17a2e5f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -202,6 +202,35 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
let cppNamespace = "::mlir::linalg";
let verify = [{ return detail::verifyGroupedConvolutionInterface($_op); }];
let methods = [
+ InterfaceMethod<[{
+ Returns the channel position.
+ }],
+ "int64_t", "getChannelPosition", (ins)
+ >,
+ InterfaceMethod<[{
+ Get number of groups.
+ }],
+ "int64_t", "getNumGroups", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getChannelPosition() - 1];
+ }]>,
+ InterfaceMethod<[{
+ Get number of input channels.
+ }],
+ "int64_t", "getNumInputChannels", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getChannelPosition()];
+ }]>,
+ InterfaceMethod<[{
+ Get number of output channels.
+ }],
+ "int64_t", "getNumOutputChannels", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return cast<ShapedType>($_op.getDpsInits()[0].getType()).getShape()[$_op.getChannelPosition()];
+ }]>,
InterfaceMethod<[{
Returns indexing maps for any spatial dimension.
}],
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 44db786a64595e..bc7e7ba004c9b1 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -507,7 +507,7 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
}
int64_t getChannelPosition() {
- return (getChannelFirstAttr().getValue()) ? 1 : getSpatialRank() + 1;
+ return (getChannelFirstAttr().getValue()) ? 2 : getSpatialRank() + 2;
}
}];
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ba28c4ed954970..c736bd064bded2 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -706,11 +706,11 @@ mlir::linalg::detail::grouped_convolution_impl::createCommonIndexingMaps(
outExprs.push_back(sp);
}
SmallVector<AffineExpr> kExprs(ks);
- inExprs.insert(inExprs.begin() + channelPos, gc.begin(), gc.end());
+ inExprs.insert(inExprs.begin() + channelPos - 1, gc.begin(), gc.end());
kExprs.insert(channelPos == 0 ? kExprs.begin()
- : kExprs.begin() + channelPos - 1,
+ : kExprs.begin() + channelPos - 2,
gfc.begin(), gfc.end());
- outExprs.insert(outExprs.begin() + channelPos, gf.begin(), gf.end());
+ outExprs.insert(outExprs.begin() + channelPos - 1, gf.begin(), gf.end());
SmallVector<AffineMap> maps(
{AffineMap::get(4 + 2 * numSpatial, 0, inExprs, ctx),
AffineMap::get(4 + 2 * numSpatial, 0, kExprs, ctx),
>From cce8517b88a98137d1e8a4d190c89c697702cd4c Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Mon, 1 Jan 2024 22:49:11 -0600
Subject: [PATCH 5/5] Implement layout attribute to generalize dim positions
(WIP)
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 9 +--
.../Dialect/Linalg/IR/LinalgInterfaces.td | 25 ++++++--
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 62 +++++++++++-------
.../mlir/Dialect/Utils/StructuredOpsUtils.td | 12 ++++
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 63 +++++++++++++------
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 2 +-
mlir/test/Dialect/Linalg/bufferize.mlir | 5 +-
mlir/test/Dialect/Linalg/named-ops.mlir | 4 +-
mlir/test/Dialect/Linalg/tile-conv.mlir | 4 +-
9 files changed, 128 insertions(+), 58 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index e2f24432c003b4..72f65c9e810c67 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -141,10 +141,11 @@ void print(LinalgOp op, OpAsmPrinter &p);
// Common implementations for GroupedConvolutionOpInterface
namespace grouped_convolution_impl {
int64_t getSpatialRank(GroupedConvolutionOpInterface op);
-ArrayAttr createCommonIndexingMaps(MLIRContext *ctx, int64_t numSpatial,
- int64_t channelPos,
- const SmallVectorImpl<int64_t> &strides,
- const SmallVectorImpl<int64_t> &dilations);
+ArrayAttr createCommonIndexingMaps(
+ MLIRContext *ctx, int64_t numSpatial,
+ const SmallVector<SmallVector<utils::GroupedConvDim>> &layouts,
+ const SmallVectorImpl<int64_t> &strides,
+ const SmallVectorImpl<int64_t> &dilations);
ArrayAttr getIteratorTypes(GroupedConvolutionOpInterface op);
} // namespace grouped_convolution_impl
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 3feadfa17a2e5f..5ae481a222e3c8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -203,9 +203,24 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
let verify = [{ return detail::verifyGroupedConvolutionInterface($_op); }];
let methods = [
InterfaceMethod<[{
- Returns the channel position.
+ Returns the layout of each operand as a list of `GroupedConvDim` enums.
}],
- "int64_t", "getChannelPosition", (ins)
+ "SmallVector<SmallVector<::mlir::utils::GroupedConvDim>>", "getLayoutsEnums", (ins)
+ >,
+ InterfaceMethod<[{
+ Returns the groups position for the input.
+ }],
+ "int64_t", "getInputGroupsPosition", (ins)
+ >,
+ InterfaceMethod<[{
+ Returns the channel position for the input.
+ }],
+ "int64_t", "getInputChannelPosition", (ins)
+ >,
+ InterfaceMethod<[{
+ Returns the channel position for the output.
+ }],
+ "int64_t", "getOutputChannelPosition", (ins)
>,
InterfaceMethod<[{
Get number of groups.
@@ -213,7 +228,7 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
"int64_t", "getNumGroups", (ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
- return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getChannelPosition() - 1];
+ return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputGroupsPosition() - 1];
}]>,
InterfaceMethod<[{
Get number of input channels.
@@ -221,7 +236,7 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
"int64_t", "getNumInputChannels", (ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
- return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getChannelPosition()];
+ return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputChannelPosition()];
}]>,
InterfaceMethod<[{
Get number of output channels.
@@ -229,7 +244,7 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
"int64_t", "getNumOutputChannels", (ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
- return cast<ShapedType>($_op.getDpsInits()[0].getType()).getShape()[$_op.getChannelPosition()];
+ return cast<ShapedType>($_op.getDpsInits()[0].getType()).getShape()[$_op.getOutputChannelPosition()];
}]>,
InterfaceMethod<[{
Returns indexing maps for any spatial dimension.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index bc7e7ba004c9b1..fcfd9f61aa75e4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -396,29 +396,23 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
}];
let description = [{
Allows any number of spatial dimensions but treats all of them as contiguous. Throughout, `S`,
- will represent all spatial dimensions. Operand layouts are determined by the `channel_first`
- `bool` attritbute. When placing the channel dim first or last, the batch dim is excluded. In
- any case, the channel and spatial dims are in the same relative order for all operands.
+ will represent all spatial dimensions. Operand layouts are determined by the `layouts`
+ `StrArrayAttr` attribute. Each element of the array is a string representing the layout of the
+ corresponding operand and should be mappable to a `GroupedConvDim` enum, i.e. one of
+ n: (batch dim)
+ g: (group dim)
+ f: (feature or output channel dim)
+ s: (all spatial dims)
+ c: (input channel dim).
- Domain: N, G, F, S, C, KS
-
- Layouts:
- `channel_first == true`:
- Input: `NGCS`
- Kernel: `FS`
- Output: `NGFS`
-
- `channel_first == false`:
- Input: `NSGC`
- Kernel: `SGFC`
- Output: `NSGF`
+ The domain will always be in the order `(N, G, F, S, C, KS)`.
}];
let arguments = (ins
Variadic<TensorOrMemref>:$inputs,
Variadic<TensorOrMemref>:$inits,
- DefaultValuedAttr<BoolAttr, "true">:$channel_first,
+ DefaultValuedAttr<StrArrayAttr, "{\"ngcs\", \"gfcs\", \"ngfs\"}">:$layouts,
OptionalAttr<I64ElementsAttr>:$strides,
OptionalAttr<I64ElementsAttr>:$dilations
);
@@ -428,11 +422,10 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<
- (ins "Value":$input, "Value":$filter, "Value":$init, CArg<"bool", "true">:$channel_first,
+ (ins "Value":$input, "Value":$filter, "Value":$init,
CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
- $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
int64_t numSpatialDims = input.getType().cast<ShapedType>().getRank() - 3;
if (strides.empty())
strides = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
@@ -449,11 +442,10 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
- "Value":$init, CArg<"bool", "true">:$channel_first,
+ "Value":$init,
CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
- $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
int64_t numSpatialDims = input.getType().cast<ShapedType>().getRank() - 3;
if (strides.empty())
strides = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
@@ -470,10 +462,9 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
- "Value":$init, "Attribute":$channel_first, "Attribute":$strides, "Attribute":$dilations,
+ "Value":$init, "Attribute":$strides, "Attribute":$dilations,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
- $_state.addAttribute(getChannelFirstAttrName($_state.name), channel_first);
$_state.addAttribute(getStridesAttrName($_state.name), strides);
$_state.addAttribute(getDilationsAttrName($_state.name), dilations);
buildStructuredOp($_builder, $_state, resultTensorTypes, {input, filter}, init,
@@ -506,8 +497,31 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
return detail::grouped_convolution_impl::getSpatialRank(*this);
}
- int64_t getChannelPosition() {
- return (getChannelFirstAttr().getValue()) ? 2 : getSpatialRank() + 2;
+ SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> getLayoutsEnums() {
+ SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> layouts;
+ for (auto attr : (*this).getLayoutsAttr().getValue()) {
+ std::string layoutStr = cast<StringAttr>(attr).getValue().str();
+ SmallVector<::mlir::utils::GroupedConvDim> layout(layoutStr.size());
+ for (size_t i = 0; i < layoutStr.size(); i++) {
+ auto maybeDimEnum = ::mlir::utils::symbolizeGroupedConvDim(layoutStr.substr(i, 1).c_str());
+ assert(maybeDimEnum);
+ layout[i] = maybeDimEnum.value();
+ }
+ layouts.push_back(layout);
+ }
+ return layouts;
+ }
+
+ int64_t getOutputChannelPosition() {
+ return 2;
+ }
+
+ int64_t getInputChannelPosition() {
+ return 2;
+ }
+
+ int64_t getInputGroupsPosition() {
+ return 1;
}
}];
}
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
index 4200343ce3e132..c7c5d617f6492c 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
@@ -20,4 +20,16 @@ def IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
let cppNamespace = "::mlir::utils";
}
+def GroupedConvDim : I32EnumAttr<"GroupedConvDim", "Convolution dim",
+ [
+ I32EnumAttrCase<"n", 0>, // batch
+ I32EnumAttrCase<"g", 1>, // group
+ I32EnumAttrCase<"f", 2>, // feature (output channel)
+ I32EnumAttrCase<"s", 3>, // spatial
+ I32EnumAttrCase<"c", 4> // channel (input channel)
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::utils";
+}
+
#endif // STRUCTURED_OPS_UTILS
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index c736bd064bded2..b98de5ee259e45 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -679,9 +679,11 @@ ArrayAttr mlir::linalg::detail::grouped_convolution_impl::getIteratorTypes(
ArrayAttr
mlir::linalg::detail::grouped_convolution_impl::createCommonIndexingMaps(
- MLIRContext *ctx, int64_t numSpatial, int64_t channelPos,
+ MLIRContext *ctx, int64_t numSpatial,
+ const SmallVector<SmallVector<utils::GroupedConvDim>> &layouts,
const SmallVectorImpl<int64_t> &strides,
const SmallVectorImpl<int64_t> &dilations) {
+ assert(layouts.size() == 3 && "expected 3 layouts: image, filter, init");
// Domain: (n, g, f, os, c, ks)
AffineExpr n = getAffineDimExpr(0, ctx);
@@ -695,26 +697,51 @@ mlir::linalg::detail::grouped_convolution_impl::createCommonIndexingMaps(
llvm::seq<int64_t>(numSpatial + 4, 2 * (numSpatial + 1) + 2),
[&](int64_t d) { return getAffineDimExpr(d, ctx); }));
- // Initialze operand accesses in nw order and insert c according to channel
- // position
- SmallVector<AffineExpr> inExprs = {n}, outExprs = {n};
- SmallVector<AffineExpr> gc = {g, c};
- SmallVector<AffineExpr> gf = {g, f};
- SmallVector<AffineExpr> gfc = {g, f, c};
+ SmallVector<AffineExpr> inSpatials;
+ inSpatials.reserve(numSpatial);
for (const auto &[sp, ksp, st, di] : llvm::zip(s, ks, strides, dilations)) {
- inExprs.push_back(sp * st + ksp * di);
- outExprs.push_back(sp);
+ inSpatials.push_back(sp * st + ksp * di);
}
- SmallVector<AffineExpr> kExprs(ks);
- inExprs.insert(inExprs.begin() + channelPos - 1, gc.begin(), gc.end());
- kExprs.insert(channelPos == 0 ? kExprs.begin()
- : kExprs.begin() + channelPos - 2,
- gfc.begin(), gfc.end());
- outExprs.insert(outExprs.begin() + channelPos - 1, gf.begin(), gf.end());
+
+ auto getExprs = [&](const SmallVector<utils::GroupedConvDim> &layout,
+ const SmallVector<AffineExpr> &spatials) {
+ SmallVector<AffineExpr> exprs(layout.size());
+ int64_t spatialDim;
+ for (const auto &[i, dim] : llvm::enumerate(layout)) {
+ switch (dim) {
+ case utils::GroupedConvDim::n:
+ exprs[i] = n;
+ break;
+ case utils::GroupedConvDim::g:
+ exprs[i] = g;
+ break;
+ case utils::GroupedConvDim::f:
+ exprs[i] = f;
+ break;
+ case utils::GroupedConvDim::s:
+ exprs[i] = spatials[0];
+ spatialDim = i;
+ break;
+ case utils::GroupedConvDim::c:
+ exprs[i] = c;
+ break;
+ default:
+ assert(false);
+ }
+ }
+ if (spatials.size() > 1)
+ exprs.insert(exprs.begin() + spatialDim + 1, spatials.begin() + 1,
+ spatials.end());
+ return exprs;
+ };
+ SmallVector<AffineExpr> inExprs = getExprs(layouts[0], inSpatials);
+ SmallVector<AffineExpr> kExprs = getExprs(layouts[1], ks);
+ SmallVector<AffineExpr> outExprs = getExprs(layouts[2], s);
SmallVector<AffineMap> maps(
- {AffineMap::get(4 + 2 * numSpatial, 0, inExprs, ctx),
- AffineMap::get(4 + 2 * numSpatial, 0, kExprs, ctx),
- AffineMap::get(4 + 2 * numSpatial, 0, outExprs, ctx)});
+ {AffineMap::get(4 + 2 * numSpatial, 0, getExprs(layouts[0], inSpatials),
+ ctx),
+ AffineMap::get(4 + 2 * numSpatial, 0, getExprs(layouts[1], ks), ctx),
+ AffineMap::get(4 + 2 * numSpatial, 0, getExprs(layouts[2], s), ctx)});
return Builder(ctx).getAffineMapArrayAttr(maps);
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c31b17082b8900..29a3f39c8696c9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1761,7 +1761,7 @@ ArrayAttr GroupedConvNDOp::getIndexingMaps() {
return cached;
cached = detail::grouped_convolution_impl::createCommonIndexingMaps(
- getContext(), getSpatialRank(), getChannelPosition(), getStridesVector(),
+ getContext(), getSpatialRank(), getLayoutsEnums(), getStridesVector(),
getDilationsVector());
(*this)->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 9d3444fe2ce9cb..876fdc9b11dc27 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -232,10 +232,11 @@ func.func public @main(%arg0: tensor<2x3xi1>) -> tensor<6xi64> {
// CHECK-DAG: %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<64x2x20x8x8x8xf32>
// CHECK: memref.copy %[[ARG2_MEMREF]], %[[INIT_BUFFER]] : memref<64x2x20x8x8x8xf32> to memref<64x2x20x8x8x8xf32>
// CHECK: linalg.grouped_conv_nd
-// CHECK-SAME: {channel_first = true, dilations = dense<2> : tensor<3xi64>, strides = dense<3> : tensor<3xi64>}
+// CHECK-SAME: dilations = dense<2> : tensor<3xi64>
+// CHECK-SAME: strides = dense<3> : tensor<3xi64>}
// CHECK-SAME: ins(%[[ARG0_MEMREF]], %[[ARG1_MEMREF]] : memref<64x2x16x26x26x26xf32>, memref<2x20x16x3x3x3xf32>)
// CHECK-SAME: outs(%[[INIT_BUFFER]] : memref<64x2x20x8x8x8xf32>)
func.func @gen_grouped_3D_channel_first_tensor(%arg0: tensor<64x2x16x26x26x26xf32>, %arg1: tensor<2x20x16x3x3x3xf32>, %arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32> {
- %0 = linalg.grouped_conv_nd {channel_first = true, strides = dense<3> : tensor<3xi64>, dilations = dense<2> : tensor<3xi64>} ins(%arg0, %arg1: tensor<64x2x16x26x26x26xf32>, tensor<2x20x16x3x3x3xf32>) outs(%arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32>
+ %0 = linalg.grouped_conv_nd {strides = dense<3> : tensor<3xi64>, dilations = dense<2> : tensor<3xi64>} ins(%arg0, %arg1: tensor<64x2x16x26x26x26xf32>, tensor<2x20x16x3x3x3xf32>) outs(%arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32>
return %0 : tensor<64x2x20x8x8x8xf32>
}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index bd728edd1ec715..24177a3a8d7fa6 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -4,8 +4,8 @@
// CHECK-LABEL: func @gen_grouped_1D_channel_first_memref
func.func @gen_grouped_1D_channel_first_memref(%arg0: memref<64x8x16x10xf32>, %arg1: memref<8x32x16x3xf32>, %arg2: memref<64x8x32x8xf32>) {
- // CHECK: grouped_conv_nd {{.*}}channel_first = true
- linalg.grouped_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x8x16x10xf32>, memref<8x32x16x3xf32>) outs(%arg2: memref<64x8x32x8xf32>)
+ // CHECK: grouped_conv_nd
+ linalg.grouped_conv_nd ins(%arg0, %arg1: memref<64x8x16x10xf32>, memref<8x32x16x3xf32>) outs(%arg2: memref<64x8x32x8xf32>)
return
}
diff --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir
index 50065ccb18cb84..475e2565ec5f94 100644
--- a/mlir/test/Dialect/Linalg/tile-conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile-conv.mlir
@@ -52,7 +52,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0)[s0] -> (d0 + s0 - 1)>
func.func @grouped_conv_2D(%arg0 : memref<?x?x?x?x?xf32>, %arg1 : memref<?x?x?x?x?xf32>, %arg2 : memref<?x?x?x?x?xf32>) {
- linalg.grouped_conv_nd ins(%arg0, %arg1 : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>) outs(%arg2 : memref<?x?x?x?x?xf32>)
+ linalg.grouped_conv_nd {layouts = ["ngcs", "gfcs", "ngfs"]} ins(%arg0, %arg1 : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>) outs(%arg2 : memref<?x?x?x?x?xf32>)
return
}
@@ -98,6 +98,6 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[SVIN:.*]] = memref.subview %[[ARG0]][%[[I]], %[[J]], 0, %[[L]], %[[M]]] [%[[T4]], %[[T5]], %[[IN_CHANNELS]], %[[T9]], %[[T10]]]
// CHECK-DAG: %[[SVKER:.*]] = memref.subview %[[ARG1]][%[[J]], %[[K]], 0, 0, 0] [%[[T5]], %[[T6]], %[[IN_CHANNELS]], %[[KW]], %[[KH]]]
// CHECK-DAG: %[[SVOUT:.*]] = memref.subview %[[ARG2]][%[[I]], %[[J]], %[[K]], %[[L]], %[[M]]] [%[[T4]], %[[T5]], %[[T6]], %[[T7]], %[[T8]]]
-// CHECK: linalg.grouped_conv_nd {channel_first = true}
+// CHECK: linalg.grouped_conv_nd {layouts = ["ngcs", "gfcs", "ngfs"]}
// CHECK-SAME: ins(%[[SVIN]], %[[SVKER]]
// CHECK-SAME: outs(%[[SVOUT]]
\ No newline at end of file
More information about the llvm-commits
mailing list