[mlir] [clang-tools-extra] [llvm] Implement grouped conv interface (PR #80870)

via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 6 08:40:11 PST 2024


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/80870

>From ed244c7cbd95294077fa603e022ac234aaf19aa2 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 30 Dec 2023 13:51:52 -0600
Subject: [PATCH 1/5] Implement GroupedConvolutionOpInterface

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |  36 +++++
 .../Dialect/Linalg/IR/LinalgInterfaces.td     |  51 +++++++
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 127 ++++++++++++++++++
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 102 ++++++++++++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 105 +++++++++++++++
 mlir/test/Dialect/Linalg/loops.mlir           |   8 +-
 mlir/test/Dialect/Linalg/named-ops.mlir       |  11 ++
 mlir/test/Dialect/Linalg/tile-conv.mlir       |   2 +-
 8 files changed, 436 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 6c8240267e7d05..e2f24432c003b4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -28,6 +28,8 @@ namespace mlir {
 namespace linalg {
 class IteratorTypeAttr;
 class LinalgOp;
+class ConvolutionOpInterface;
+class GroupedConvolutionOpInterface;
 
 namespace detail {
 /// Implementation of the method that check if given operands
@@ -115,6 +117,37 @@ bool isaCopyOpInterface(LinalgOp linalgOp);
 
 namespace detail {
 
+// Common implementations for ConvolutionOpInterface
+namespace convolution_impl {
+// Returns strides as a vector.
+SmallVector<int64_t, 2> getStrides(ConvolutionOpInterface op);
+// Returns dilations as a vector.
+SmallVector<int64_t, 2> getDilations(ConvolutionOpInterface op);
+// Region builder for basic convolution
+void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                   ArrayRef<NamedAttribute> attrs);
+// Region builder for basic quantized convolution
+void quantizedRegionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                            ArrayRef<NamedAttribute> attrs);
+void getEffects(
+    Operation *op,
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects);
+ParseResult parse(OpAsmParser &parser, OperationState &result,
+                  bool isQuantized = false);
+void print(LinalgOp op, OpAsmPrinter &p);
+} // namespace convolution_impl
+
+// Common implementations for GroupedConvolutionOpInterface
+namespace grouped_convolution_impl {
+int64_t getSpatialRank(GroupedConvolutionOpInterface op);
+ArrayAttr createCommonIndexingMaps(MLIRContext *ctx, int64_t numSpatial,
+                                   int64_t channelPos,
+                                   const SmallVectorImpl<int64_t> &strides,
+                                   const SmallVectorImpl<int64_t> &dilations);
+ArrayAttr getIteratorTypes(GroupedConvolutionOpInterface op);
+} // namespace grouped_convolution_impl
+
 /// Returns true if the block contains a contraction of the following form:
 ///
 ///   %0 = <elemwise>(permutation-of(cu(block-argument-0),
@@ -171,6 +204,9 @@ LogicalResult verifyContractionInterface(Operation *op);
 /// Verify that `op` conforms to the ConvolutionOpInterface.
 LogicalResult verifyConvolutionInterface(Operation *op);
 
+/// Verify that `op` conforms to the GroupedConvolutionOpInterface.
+LogicalResult verifyGroupedConvolutionInterface(Operation *op);
+
 /// Verify that `op` conforms to the FillOpInterface.
 LogicalResult verifyFillInterface(Operation *op);
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9b..170ebf9d43030f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -175,6 +175,57 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
         return $_op.getOperation()->getOperand(1);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Return the spatial rank.",
+      /*retTy=*/"int64_t",
+      /*methodName=*/"getSpatialRank",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // Most convolutions' inputs have batch, channel and spatial dims
+        return cast<ShapedType>(image().getType()).getRank() - 2;
+      }]
+    >
+  ];
+}
+
+def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInterface", [
+  LinalgConvolutionOpInterface]> {
+  let description = [{
+    A grouped convolution is defined in general terms:
+    1. It is a convolution as defined by `ConvolutionOpInterface`.
+    2. Operands have the following distinct dimensions (excluding batch in input/output): group, channel, spatial
+    3. `input_rank == kernel_rank == output_rank` (including batch in input/output)
+    4. Reductions are along the input channel and spatial dimensions while group, output channel
+       and output spatial dimensions are parallel.
+  }];
+  let cppNamespace = "::mlir::linalg";
+  let verify = [{ return detail::verifyGroupedConvolutionInterface($_op); }];
+  let methods = [
+    InterfaceMethod<[{
+      Returns indexing maps for any spatial dimension.
+    }],
+    "::mlir::ArrayAttr", "getIteratorTypes", (ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return detail::grouped_convolution_impl::getIteratorTypes($_op);
+      }]>,
+    InterfaceMethod<[{
+      Returns strides.
+    }],
+    "::llvm::SmallVector<int64_t, 2>", "getStridesVector", (ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return detail::convolution_impl::getStrides($_op);
+    }]>,
+    InterfaceMethod<[{
+      Returns dilations.
+    }],
+    "::llvm::SmallVector<int64_t, 2>", "getDilationsVector", (ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return detail::convolution_impl::getDilations($_op);
+    }]>
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 751edd02288301..44db786a64595e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -384,6 +384,133 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// GroupedConvNDOp ops.
+//===----------------------------------------------------------------------===//
+
+def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
+  [AttrSizedOperandSegments, LinalgGroupedConvolutionOpInterface]> {
+    
+  let summary = [{
+    Performs N-D grouped convolution with switchable channel position; either first or last.
+  }];
+  let description = [{
+    Allows any number of spatial dimensions but treats all of them as contiguous.  Throughout, `S`
+    will represent all spatial dimensions.  Operand layouts are determined by the `channel_first`
+    `bool` attribute.  When placing the channel dim first or last, the batch dim is excluded.  In
+    any case, the channel and spatial dims are in the same relative order for all operands.
+    
+    Domain: N, G, F, S, C, KS
+
+    Layouts:
+      `channel_first == true`:
+        Input: `NGCS`
+        Kernel: `GFCS`
+        Output: `NGFS`
+
+      `channel_first == false`:
+        Input: `NSGC`
+        Kernel: `SGFC`
+        Output: `NSGF`
+
+  }];
+
+    let arguments = (ins
+      Variadic<TensorOrMemref>:$inputs,
+      Variadic<TensorOrMemref>:$inits,
+      DefaultValuedAttr<BoolAttr, "true">:$channel_first,
+      OptionalAttr<I64ElementsAttr>:$strides,
+      OptionalAttr<I64ElementsAttr>:$dilations
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "Value":$input, "Value":$filter, "Value":$init, CArg<"bool", "true">:$channel_first,
+            CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
+        int64_t numSpatialDims = input.getType().cast<ShapedType>().getRank() - 3;
+        if (strides.empty())
+          strides = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
+        if (dilations.empty())
+          dilations = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
+        $_state.addAttribute(getStridesAttrName($_state.name),
+          ::mlir::DenseElementsAttr::get(
+            ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), strides));
+        $_state.addAttribute(getDilationsAttrName($_state.name),
+          ::mlir::DenseElementsAttr::get(
+            ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), dilations));
+        buildStructuredOp($_builder, $_state, std::nullopt, {input, filter}, init,
+          attributes, GroupedConvNDOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter, 
+            "Value":$init, CArg<"bool", "true">:$channel_first,
+            CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
+        int64_t numSpatialDims = input.getType().cast<ShapedType>().getRank() - 3;
+        if (strides.empty())
+          strides = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
+        if (dilations.empty())
+          dilations = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
+        $_state.addAttribute(getStridesAttrName($_state.name),
+          ::mlir::DenseElementsAttr::get(
+            ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), strides));
+        $_state.addAttribute(getDilationsAttrName($_state.name),
+          ::mlir::DenseElementsAttr::get(
+            ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), dilations));
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+          {input, filter}, init, attributes, GroupedConvNDOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter, 
+      "Value":$init, "Attribute":$channel_first, "Attribute":$strides, "Attribute":$dilations,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute(getChannelFirstAttrName($_state.name), channel_first);
+        $_state.addAttribute(getStridesAttrName($_state.name), strides);
+        $_state.addAttribute(getDilationsAttrName($_state.name), dilations);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, {input, filter}, init,
+          attributes, GroupedConvNDOp::getRegionBuilder());
+      }]>
+    ];
+
+    // TODO: Figure out how to move this to the interface
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      void print(::mlir::OpAsmPrinter &printer) {
+        return detail::convolution_impl::print(*this, printer);
+      }
+      static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
+                                        ::mlir::OperationState &result) {
+        return detail::convolution_impl::parse(parser, result);
+      }
+      static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+                                mlir::ArrayRef<mlir::NamedAttribute>)>
+      getRegionBuilder() {
+        return detail::convolution_impl::regionBuilder;
+      }
+      // Implement functions necessary for DestinationStyleOpInterface.
+      MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
+
+      // Implement functions necessary for LinalgOp.
+      ArrayAttr getIndexingMaps();
+
+      // Implement functions necessary for GroupedConvolutionOpInterface
+      int64_t getSpatialRank() {
+        return detail::grouped_convolution_impl::getSpatialRank(*this);
+      }
+
+      int64_t getChannelPosition() {
+          return (getChannelFirstAttr().getValue()) ? 1 : getSpatialRank() + 1;
+      }
+    }];
+}
 
 //===----------------------------------------------------------------------===//
 // Transpose op.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ba419d32f22a3e..ba28c4ed954970 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -638,6 +638,108 @@ enum class MatchConvolutionResult {
 };
 } // namespace mlir::linalg::detail
 
+SmallVector<int64_t, 2>
+mlir::linalg::detail::convolution_impl::getStrides(ConvolutionOpInterface op) {
+  auto maybeStridesAttr = op->getAttrOfType<DenseIntElementsAttr>("strides");
+  if (!maybeStridesAttr) {
+    OpBuilder builder(op.getContext());
+    return SmallVector<int64_t, 2>(op.getSpatialRank(), 1);
+  }
+  return llvm::to_vector(maybeStridesAttr.getValues<int64_t>());
+}
+
+SmallVector<int64_t, 2> mlir::linalg::detail::convolution_impl::getDilations(
+    ConvolutionOpInterface op) {
+  auto maybeDilationsAttr =
+      op->getAttrOfType<DenseIntElementsAttr>("dilations");
+  if (!maybeDilationsAttr) {
+    OpBuilder builder(op.getContext());
+    return SmallVector<int64_t, 2>(op.getSpatialRank(), 1);
+  }
+  return llvm::to_vector(maybeDilationsAttr.getValues<int64_t>());
+}
+
+int64_t mlir::linalg::detail::grouped_convolution_impl::getSpatialRank(
+    GroupedConvolutionOpInterface op) {
+  return cast<ShapedType>(op.image().getType()).getRank() - 3;
+}
+
+ArrayAttr mlir::linalg::detail::grouped_convolution_impl::getIteratorTypes(
+    GroupedConvolutionOpInterface op) {
+  int64_t numSpatialDims = op.getSpatialRank();
+  SmallVector<Attribute> iteratorTypes(
+      3 + numSpatialDims, IteratorTypeAttr::get(op.getContext(), par));
+  SmallVector<Attribute> reductions(
+      numSpatialDims + 1, IteratorTypeAttr::get(op.getContext(), red));
+  iteratorTypes.insert(iteratorTypes.end(), reductions.begin(),
+                       reductions.end());
+
+  return Builder(op.getContext()).getArrayAttr(iteratorTypes);
+}
+
+ArrayAttr
+mlir::linalg::detail::grouped_convolution_impl::createCommonIndexingMaps(
+    MLIRContext *ctx, int64_t numSpatial, int64_t channelPos,
+    const SmallVectorImpl<int64_t> &strides,
+    const SmallVectorImpl<int64_t> &dilations) {
+
+  // Domain: (n, g, f, os, c, ks)
+  AffineExpr n = getAffineDimExpr(0, ctx);
+  AffineExpr g = getAffineDimExpr(1, ctx);
+  AffineExpr f = getAffineDimExpr(2, ctx);
+  SmallVector<AffineExpr> s(
+      llvm::map_range(llvm::seq<int64_t>(3, numSpatial + 3),
+                      [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+  AffineExpr c = getAffineDimExpr(numSpatial + 3, ctx);
+  SmallVector<AffineExpr> ks(llvm::map_range(
+      llvm::seq<int64_t>(numSpatial + 4, 2 * (numSpatial + 1) + 2),
+      [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+
+  // Initialize operand accesses in (batch, spatial) order, then insert the
+  // group and channel dims according to the channel position
+  SmallVector<AffineExpr> inExprs = {n}, outExprs = {n};
+  SmallVector<AffineExpr> gc = {g, c};
+  SmallVector<AffineExpr> gf = {g, f};
+  SmallVector<AffineExpr> gfc = {g, f, c};
+  for (const auto &[sp, ksp, st, di] : llvm::zip(s, ks, strides, dilations)) {
+    inExprs.push_back(sp * st + ksp * di);
+    outExprs.push_back(sp);
+  }
+  SmallVector<AffineExpr> kExprs(ks);
+  inExprs.insert(inExprs.begin() + channelPos, gc.begin(), gc.end());
+  kExprs.insert(channelPos == 0 ? kExprs.begin()
+                                : kExprs.begin() + channelPos - 1,
+                gfc.begin(), gfc.end());
+  outExprs.insert(outExprs.begin() + channelPos, gf.begin(), gf.end());
+  SmallVector<AffineMap> maps(
+      {AffineMap::get(4 + 2 * numSpatial, 0, inExprs, ctx),
+       AffineMap::get(4 + 2 * numSpatial, 0, kExprs, ctx),
+       AffineMap::get(4 + 2 * numSpatial, 0, outExprs, ctx)});
+
+  return Builder(ctx).getAffineMapArrayAttr(maps);
+}
+
+LogicalResult
+mlir::linalg::detail::verifyGroupedConvolutionInterface(Operation *op) {
+  if (failed(verifyConvolutionInterface(op)))
+    return failure();
+  if (GroupedConvolutionOpInterface conv =
+          dyn_cast<GroupedConvolutionOpInterface>(op)) {
+    const auto imageType = conv.image().getType().dyn_cast<ShapedType>();
+    const auto imageRank = imageType.getRank();
+    const auto kernelRank =
+        conv.filter().getType().cast<ShapedType>().getRank();
+    const auto initType =
+        cast<LinalgOp>(op).getDpsInits()[0].getType().dyn_cast<ShapedType>();
+    const auto initRank = initType.getRank();
+    if (imageRank != kernelRank || imageRank != initRank)
+      return op->emitError(
+          "Rank relationship must be `in_rank == out_rank == kernel_rank`");
+    return success();
+  }
+  return failure();
+}
+
 mlir::linalg::detail::MatchConvolutionResult
 mlir::linalg::detail::isConvolutionInterfaceImpl(
     Operation *op, ConvolutionDimensions *dimensions) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b68aa77fd83a1c..c31b17082b8900 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1663,6 +1663,111 @@ LogicalResult ReduceOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ConvolutionOpInterface
+//===----------------------------------------------------------------------===//
+
+// There must be a way to avoid defining the following 3 functions
+ParseResult mlir::linalg::detail::convolution_impl::parse(
+    OpAsmParser &parser, OperationState &result, bool isQuantized) {
+  if (isQuantized)
+    return parseNamedStructuredOp(
+        parser, result, 5,
+        mlir::linalg::detail::convolution_impl::quantizedRegionBuilder);
+  return parseNamedStructuredOp(
+      parser, result, 3, mlir::linalg::detail::convolution_impl::regionBuilder);
+}
+
+void mlir::linalg::detail::convolution_impl::print(LinalgOp op,
+                                                   OpAsmPrinter &p) {
+  printNamedStructuredOp(p, op.getOperation(), op.getDpsInputs(),
+                         op.getDpsInits());
+}
+
+// Build {mul, add} region for convolution
+void mlir::linalg::detail::convolution_impl::regionBuilder(
+    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "ConvolutionInterface regionBuilder expects 3 (>=0) args");
+  RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
+  SmallVector<Value> yields;
+
+  Value value1 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+                         block.getArgument(0));
+  Value value2 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+                         block.getArgument(1));
+  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+  Value value4 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+  yields.push_back(value4);
+  helper.yieldOutputs(yields);
+}
+
+void mlir::linalg::detail::convolution_impl::quantizedRegionBuilder(
+    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 5 &&
+         "ConvolutionInterface regionBuilder expects 5 args");
+  RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
+  Value value1 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+                         block.getArgument(0));
+  Value value2 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+                         block.getArgument(2));
+  Value value3 = helper.buildBinaryFn(BinaryFn::sub, value1, value2);
+  Value value4 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+                         block.getArgument(1));
+  Value value5 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+                         block.getArgument(3));
+  Value value6 = helper.buildBinaryFn(BinaryFn::sub, value4, value5);
+  Value value7 = helper.buildBinaryFn(BinaryFn::mul, value3, value6);
+  Value value8 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(4), value7);
+  helper.yieldOutputs({value8});
+}
+
+void mlir::linalg::detail::convolution_impl::getEffects(
+    Operation *op,
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (!isa<ConvolutionOpInterface>(op))
+    return;
+  if (LinalgOp linalgOp = dyn_cast<LinalgOp>(op)) {
+    if (linalgOp.hasTensorSemantics())
+      return;
+    getGenericEffectsImpl(effects, linalgOp.getOperation()->getResults(),
+                          linalgOp.getDpsInputs(), linalgOp.getDpsInits());
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// GroupedConvNDOp
+//===----------------------------------------------------------------------===//
+
+void GroupedConvNDOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  return detail::convolution_impl::getEffects(*this, effects);
+}
+
+ArrayAttr GroupedConvNDOp::getIndexingMaps() {
+  ArrayAttr cached = (*this)->getAttrOfType<ArrayAttr>(
+      LinalgDialect::kMemoizedIndexingMapsAttrName);
+  if (cached)
+    return cached;
+
+  cached = detail::grouped_convolution_impl::createCommonIndexingMaps(
+      getContext(), getSpatialRank(), getChannelPosition(), getStridesVector(),
+      getDilationsVector());
+
+  (*this)->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+  return cached;
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 8c13422fd63833..640680483130d7 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1,12 +1,10 @@
-// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s
-// RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefix=CHECKPARALLEL %s
+// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s --check-prefixes=COMMON,CHECK
+// RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefixes=COMMON,CHECKPARALLEL %s
 
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
 // RUN: mlir-opt %s -convert-linalg-to-loops -test-lower-to-llvm -o=/dev/null 2>&1
 
-// CHECK: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
-
-// CHECKPARALLEL: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// COMMON: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
 
 func.func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
   %c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 29977a71dbb864..bd728edd1ec715 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1,5 +1,16 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
+// -----
+
+// CHECK-LABEL: func @gen_grouped_1D_channel_first_memref
+func.func @gen_grouped_1D_channel_first_memref(%arg0: memref<64x8x16x10xf32>, %arg1: memref<8x32x16x3xf32>, %arg2: memref<64x8x32x8xf32>) {
+  // CHECK: grouped_conv_nd {{.*}}channel_first = true
+    linalg.grouped_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x8x16x10xf32>, memref<8x32x16x3xf32>) outs(%arg2: memref<64x8x32x8xf32>)
+    return
+}
+
+// -----
+
 // CHECK-LABEL: func @depthwise_conv_1d_nwc_wcm
 func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<1x12x8xf32>, %filter: tensor<3x8x8xf32>) -> tensor<1x10x8x8xf32> {
   %zero = arith.constant 0.000000e+00 : f32
diff --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir
index 4a940f12662e6c..1662f5c45fe804 100644
--- a/mlir/test/Dialect/Linalg/tile-conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile-conv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -transform-interpreter -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter -canonicalize -split-input-file | FileCheck %s
 
 //  CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
 //  CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>

>From dbd721546e2f2811f4ed6a118412c4cd198530d6 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 31 Dec 2023 09:58:37 -0600
Subject: [PATCH 2/5] Add bufferization test

---
 mlir/test/Dialect/Linalg/bufferize.mlir | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 29f27e6838e661..9d3444fe2ce9cb 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -217,3 +217,25 @@ func.func public @main(%arg0: tensor<2x3xi1>) -> tensor<6xi64> {
   %3 = func.call @csum(%2) : (tensor<6xi64>) -> tensor<6xi64>
   return %3 : tensor<6xi64>
 }
+
+
+
+// -----
+
+// CHECK-LABEL:   func @gen_grouped_3D_channel_first_tensor(
+// CHECK-SAME:                                   %[[ARG0_TENSOR:.*]]: tensor<64x2x16x26x26x26xf32>,
+// CHECK-SAME:                                   %[[ARG1_TENSOR:.*]]: tensor<2x20x16x3x3x3xf32>,
+// CHECK-SAME:                                   %[[ARG2_TENSOR:.*]]: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32> {
+// CHECK-DAG:       %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0_TENSOR]] : memref<64x2x16x26x26x26xf32>
+// CHECK-DAG:       %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1_TENSOR]] : memref<2x20x16x3x3x3xf32>
+// CHECK-DAG:       %[[ARG2_MEMREF:.*]] = bufferization.to_memref %[[ARG2_TENSOR]] : memref<64x2x20x8x8x8xf32>
+// CHECK-DAG:       %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<64x2x20x8x8x8xf32>
+// CHECK:           memref.copy %[[ARG2_MEMREF]], %[[INIT_BUFFER]] : memref<64x2x20x8x8x8xf32> to memref<64x2x20x8x8x8xf32>
+// CHECK:           linalg.grouped_conv_nd
+// CHECK-SAME:      {channel_first = true, dilations = dense<2> : tensor<3xi64>, strides = dense<3> : tensor<3xi64>}
+// CHECK-SAME:      ins(%[[ARG0_MEMREF]], %[[ARG1_MEMREF]] : memref<64x2x16x26x26x26xf32>, memref<2x20x16x3x3x3xf32>)
+// CHECK-SAME:      outs(%[[INIT_BUFFER]] : memref<64x2x20x8x8x8xf32>)
+func.func @gen_grouped_3D_channel_first_tensor(%arg0: tensor<64x2x16x26x26x26xf32>, %arg1: tensor<2x20x16x3x3x3xf32>, %arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32> {
+    %0 = linalg.grouped_conv_nd {channel_first = true, strides = dense<3> : tensor<3xi64>, dilations = dense<2> : tensor<3xi64>} ins(%arg0, %arg1: tensor<64x2x16x26x26x26xf32>, tensor<2x20x16x3x3x3xf32>) outs(%arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32>
+    return %0 : tensor<64x2x20x8x8x8xf32> 
+}

>From 162904694b84924261c6ddd1f9f4430df66f6a10 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 31 Dec 2023 10:46:45 -0600
Subject: [PATCH 3/5] Add tiling regression test

---
 mlir/test/Dialect/Linalg/tile-conv.mlir | 60 +++++++++++++++++++++++++
 1 file changed, 60 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir
index 1662f5c45fe804..50065ccb18cb84 100644
--- a/mlir/test/Dialect/Linalg/tile-conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile-conv.mlir
@@ -41,3 +41,63 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:       linalg.conv_2d
 //  CHECK-SAME:         ins(%[[SVIN]], %[[SVKER]]
 //  CHECK-SAME:         outs(%[[SVOUT]]
+
+// -----
+
+//  CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
+//  CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
+//  CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+//  CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+//  CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 6)>
+//  CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0)[s0] -> (d0 + s0 - 1)>
+
+func.func @grouped_conv_2D(%arg0 : memref<?x?x?x?x?xf32>, %arg1 : memref<?x?x?x?x?xf32>, %arg2 : memref<?x?x?x?x?xf32>) {
+  linalg.grouped_conv_nd ins(%arg0, %arg1 : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>) outs(%arg2 : memref<?x?x?x?x?xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.grouped_conv_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop:5 = transform.structured.tile_using_for %0 [2, 3, 4, 5, 6] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+//       CHECK: func @grouped_conv_2D
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?x?x?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?x?x?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?x?x?x?xf32>
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+//   CHECK-DAG:   %[[C6:.*]] = arith.constant 6 : index
+//   CHECK-DAG:   %[[BATCH:.*]] = memref.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[GROUPS:.*]] = memref.dim %[[ARG0]], %[[C1]]
+//   CHECK-DAG:   %[[IN_CHANNELS:.*]] = memref.dim %[[ARG0]], %[[C2]]
+//   CHECK-DAG:   %[[OUT_CHANNELS:.*]] = memref.dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[KW:.*]] = memref.dim %[[ARG1]], %[[C3]]
+//   CHECK-DAG:   %[[KH:.*]] = memref.dim %[[ARG1]], %[[C4]]
+//   CHECK-DAG:   %[[W:.*]] = memref.dim %[[ARG2]], %[[C3]]
+//   CHECK-DAG:   %[[H:.*]] = memref.dim %[[ARG2]], %[[C4]]
+//       CHECK:   scf.for %[[I:.*]] = %[[C0]] to %[[BATCH]] step %[[C2]]
+//       CHECK:     %[[T4:.*]] = affine.min #[[MAP0]](%[[I]])[%[[BATCH]]]
+//       CHECK:     scf.for %[[J:.*]] = %[[C0]] to %[[GROUPS]] step %[[C3]]
+//       CHECK:       %[[T5:.*]] = affine.min #[[MAP1]](%[[J]])[%[[GROUPS]]]
+//       CHECK:     scf.for %[[K:.*]] = %[[C0]] to %[[OUT_CHANNELS]] step %[[C4]]
+//   CHECK-DAG:       %[[T6:.*]] = affine.min #[[MAP2]](%[[K]])[%[[OUT_CHANNELS]]]
+//       CHECK:       scf.for %[[L:.*]] = %[[C0]] to %[[W]] step %[[C5]]
+//   CHECK-DAG:         %[[T7:.*]] = affine.min #[[MAP3]](%[[L]])[%[[W]]]
+//       CHECK:         scf.for %[[M:.*]] = %[[C0]] to %[[H]] step %[[C6]]
+//   CHECK-DAG:           %[[T8:.*]] = affine.min #[[MAP4]](%[[M]])[%[[H]]]
+//   CHECK-DAG:           %[[T9:.*]] = affine.apply #[[MAP5]](%[[T7]])[%[[KW]]]
+//   CHECK-DAG:           %[[T10:.*]] = affine.apply #[[MAP5]](%[[T8]])[%[[KH]]]
+//   CHECK-DAG:           %[[SVIN:.*]] = memref.subview %[[ARG0]][%[[I]], %[[J]], 0, %[[L]], %[[M]]] [%[[T4]], %[[T5]], %[[IN_CHANNELS]], %[[T9]], %[[T10]]]
+//   CHECK-DAG:           %[[SVKER:.*]] = memref.subview %[[ARG1]][%[[J]], %[[K]], 0, 0, 0] [%[[T5]], %[[T6]], %[[IN_CHANNELS]], %[[KW]], %[[KH]]]
+//   CHECK-DAG:           %[[SVOUT:.*]] = memref.subview %[[ARG2]][%[[I]], %[[J]], %[[K]], %[[L]], %[[M]]] [%[[T4]], %[[T5]], %[[T6]], %[[T7]], %[[T8]]]
+//       CHECK:           linalg.grouped_conv_nd {channel_first = true}
+//  CHECK-SAME:             ins(%[[SVIN]], %[[SVKER]]
+//  CHECK-SAME:             outs(%[[SVOUT]]
\ No newline at end of file

>From 1969a22b7b5b50dfeaf27a56237fdbaf99947ae2 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 31 Dec 2023 11:32:46 -0600
Subject: [PATCH 4/5] Add interface methods for getting channel and group sizes

---
 .../Dialect/Linalg/IR/LinalgInterfaces.td     | 29 +++++++++++++++++++
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  2 +-
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    |  6 ++--
 3 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 170ebf9d43030f..3feadfa17a2e5f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -202,6 +202,35 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
   let cppNamespace = "::mlir::linalg";
   let verify = [{ return detail::verifyGroupedConvolutionInterface($_op); }];
   let methods = [
+    InterfaceMethod<[{
+      Returns the channel position.
+    }],
+    "int64_t", "getChannelPosition", (ins)
+    >,
+    InterfaceMethod<[{
+      Get number of groups. 
+    }],
+    "int64_t", "getNumGroups", (ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+      return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getChannelPosition() - 1];
+    }]>,
+    InterfaceMethod<[{
+      Get number of input channels. 
+    }],
+    "int64_t", "getNumInputChannels", (ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+      return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getChannelPosition()];
+    }]>,
+    InterfaceMethod<[{
+      Get number of output channels. 
+    }],
+    "int64_t", "getNumOutputChannels", (ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+      return cast<ShapedType>($_op.getDpsInits()[0].getType()).getShape()[$_op.getChannelPosition()];
+    }]>,
     InterfaceMethod<[{
       Returns indexing maps for any spatial dimension.
     }],
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 44db786a64595e..bc7e7ba004c9b1 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -507,7 +507,7 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
       }
 
       int64_t getChannelPosition() {
-          return (getChannelFirstAttr().getValue()) ? 1 : getSpatialRank() + 1;
+          return (getChannelFirstAttr().getValue()) ? 2 : getSpatialRank() + 2;
       }
     }];
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ba28c4ed954970..c736bd064bded2 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -706,11 +706,11 @@ mlir::linalg::detail::grouped_convolution_impl::createCommonIndexingMaps(
     outExprs.push_back(sp);
   }
   SmallVector<AffineExpr> kExprs(ks);
-  inExprs.insert(inExprs.begin() + channelPos, gc.begin(), gc.end());
+  inExprs.insert(inExprs.begin() + channelPos - 1, gc.begin(), gc.end());
   kExprs.insert(channelPos == 0 ? kExprs.begin()
-                                : kExprs.begin() + channelPos - 1,
+                                : kExprs.begin() + channelPos - 2,
                 gfc.begin(), gfc.end());
-  outExprs.insert(outExprs.begin() + channelPos, gf.begin(), gf.end());
+  outExprs.insert(outExprs.begin() + channelPos - 1, gf.begin(), gf.end());
   SmallVector<AffineMap> maps(
       {AffineMap::get(4 + 2 * numSpatial, 0, inExprs, ctx),
        AffineMap::get(4 + 2 * numSpatial, 0, kExprs, ctx),

>From cce8517b88a98137d1e8a4d190c89c697702cd4c Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Mon, 1 Jan 2024 22:49:11 -0600
Subject: [PATCH 5/5] Implement layout attribute to generalize dim positions
 (WIP)

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |  9 +--
 .../Dialect/Linalg/IR/LinalgInterfaces.td     | 25 ++++++--
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 62 +++++++++++-------
 .../mlir/Dialect/Utils/StructuredOpsUtils.td  | 12 ++++
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 63 +++++++++++++------
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  2 +-
 mlir/test/Dialect/Linalg/bufferize.mlir       |  5 +-
 mlir/test/Dialect/Linalg/named-ops.mlir       |  4 +-
 mlir/test/Dialect/Linalg/tile-conv.mlir       |  4 +-
 9 files changed, 128 insertions(+), 58 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index e2f24432c003b4..72f65c9e810c67 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -141,10 +141,11 @@ void print(LinalgOp op, OpAsmPrinter &p);
 // Common implementations for GroupedConvolutionOpInterface
 namespace grouped_convolution_impl {
 int64_t getSpatialRank(GroupedConvolutionOpInterface op);
-ArrayAttr createCommonIndexingMaps(MLIRContext *ctx, int64_t numSpatial,
-                                   int64_t channelPos,
-                                   const SmallVectorImpl<int64_t> &strides,
-                                   const SmallVectorImpl<int64_t> &dilations);
+ArrayAttr createCommonIndexingMaps(
+    MLIRContext *ctx, int64_t numSpatial,
+    const SmallVector<SmallVector<utils::GroupedConvDim>> &layouts,
+    const SmallVectorImpl<int64_t> &strides,
+    const SmallVectorImpl<int64_t> &dilations);
 ArrayAttr getIteratorTypes(GroupedConvolutionOpInterface op);
 } // namespace grouped_convolution_impl
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 3feadfa17a2e5f..5ae481a222e3c8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -203,9 +203,24 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
   let verify = [{ return detail::verifyGroupedConvolutionInterface($_op); }];
   let methods = [
     InterfaceMethod<[{
-      Returns the channel position.
+      Returns the layout of each operand as a list of `GroupedConvDim` enums.
     }],
-    "int64_t", "getChannelPosition", (ins)
+    "SmallVector<SmallVector<::mlir::utils::GroupedConvDim>>", "getLayoutsEnums", (ins)
+    >,
+    InterfaceMethod<[{
+      Returns the groups position for the input.
+    }],
+    "int64_t", "getInputGroupsPosition", (ins)
+    >,
+    InterfaceMethod<[{
+      Returns the channel position for the input.
+    }],
+    "int64_t", "getInputChannelPosition", (ins)
+    >,
+    InterfaceMethod<[{
+      Returns the channel position for the output.
+    }],
+    "int64_t", "getOutputChannelPosition", (ins)
     >,
     InterfaceMethod<[{
       Get number of groups. 
@@ -213,7 +228,7 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
     "int64_t", "getNumGroups", (ins),
       /*methodBody=*/[{}],
       /*defaultImplementation=*/[{
-      return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getChannelPosition() - 1];
+      return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputGroupsPosition() - 1];
     }]>,
     InterfaceMethod<[{
       Get number of input channels. 
@@ -221,7 +236,7 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
     "int64_t", "getNumInputChannels", (ins),
       /*methodBody=*/[{}],
       /*defaultImplementation=*/[{
-      return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getChannelPosition()];
+      return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputChannelPosition()];
     }]>,
     InterfaceMethod<[{
       Get number of output channels. 
@@ -229,7 +244,7 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
     "int64_t", "getNumOutputChannels", (ins),
       /*methodBody=*/[{}],
       /*defaultImplementation=*/[{
-      return cast<ShapedType>($_op.getDpsInits()[0].getType()).getShape()[$_op.getChannelPosition()];
+      return cast<ShapedType>($_op.getDpsInits()[0].getType()).getShape()[$_op.getOutputChannelPosition()];
     }]>,
     InterfaceMethod<[{
       Returns indexing maps for any spatial dimension.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index bc7e7ba004c9b1..fcfd9f61aa75e4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -396,29 +396,23 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
   }];
   let description = [{
     Allows any number of spatial dimensions but treats all of them as contiguous.  Throughout, `S`,
-    will represent all spatial dimensions.  Operand layouts are determined by the `channel_first`
-    `bool` attritbute.  When placing the channel dim first or last, the batch dim is excluded.  In
-    any case, the channel and spatial dims are in the same relative order for all operands.
+    will represent all spatial dimensions.  Operand layouts are determined by the `layouts`
+    `StrArrayAttr` attribute.  Each element of the array is a string representing the layout of the
+    corresponding operand and should be mappable to a `GroupedConvDim` enum, i.e. one of
+      n: (batch dim)
+      g: (group dim)
+      f: (feature or output channel dim)
+      s: (all spatial dims)
+      c: (input channel dim).
     
-    Domain: N, G, F, S, C, KS
-
-    Layouts:
-      `channel_first == true`:
-        Input: `NGCS`
-        Kernel: `FS`
-        Output: `NGFS`
-
-      `channel_first == false`:
-        Input: `NSGC`
-        Kernel: `SGFC`
-        Output: `NSGF`
+    The domain will always be in the order `(N, G, F, S, C, KS)`.
 
   }];
 
     let arguments = (ins
       Variadic<TensorOrMemref>:$inputs,
       Variadic<TensorOrMemref>:$inits,
-      DefaultValuedAttr<BoolAttr, "true">:$channel_first,
+      DefaultValuedAttr<StrArrayAttr, "{\"ngcs\", \"gfcs\", \"ngfs\"}">:$layouts,
       OptionalAttr<I64ElementsAttr>:$strides,
       OptionalAttr<I64ElementsAttr>:$dilations
     );
@@ -428,11 +422,10 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
     let skipDefaultBuilders = 1;
     let builders = [
       OpBuilder<
-      (ins "Value":$input, "Value":$filter, "Value":$init, CArg<"bool", "true">:$channel_first,
+      (ins "Value":$input, "Value":$filter, "Value":$init,
             CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
             CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
       [{
-        $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
         int64_t numSpatialDims = input.getType().cast<ShapedType>().getRank() - 3;
         if (strides.empty())
           strides = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
@@ -449,11 +442,10 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
       }]>,
       OpBuilder<
       (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter, 
-            "Value":$init, CArg<"bool", "true">:$channel_first,
+            "Value":$init,
             CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
             CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
       [{
-        $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
         int64_t numSpatialDims = input.getType().cast<ShapedType>().getRank() - 3;
         if (strides.empty())
           strides = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
@@ -470,10 +462,9 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
       }]>,
       OpBuilder<
       (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter, 
-      "Value":$init, "Attribute":$channel_first, "Attribute":$strides, "Attribute":$dilations,
+      "Value":$init, "Attribute":$strides, "Attribute":$dilations,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
       [{
-        $_state.addAttribute(getChannelFirstAttrName($_state.name), channel_first);
         $_state.addAttribute(getStridesAttrName($_state.name), strides);
         $_state.addAttribute(getDilationsAttrName($_state.name), dilations);
         buildStructuredOp($_builder, $_state, resultTensorTypes, {input, filter}, init,
@@ -506,8 +497,31 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
         return detail::grouped_convolution_impl::getSpatialRank(*this);
       }
 
-      int64_t getChannelPosition() {
-          return (getChannelFirstAttr().getValue()) ? 2 : getSpatialRank() + 2;
+      SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> getLayoutsEnums() {
+        SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> layouts;
+        for (auto attr : (*this).getLayoutsAttr().getValue()) {
+          std::string layoutStr = cast<StringAttr>(attr).getValue().str();
+          SmallVector<::mlir::utils::GroupedConvDim> layout(layoutStr.size());
+          for (size_t i = 0; i < layoutStr.size(); i++) {
+            auto maybeDimEnum = ::mlir::utils::symbolizeGroupedConvDim(layoutStr.substr(i, 1).c_str());
+            assert(maybeDimEnum);
+            layout[i] = maybeDimEnum.value(); 
+          }
+          layouts.push_back(layout);
+        }
+        return layouts;
+      }
+
+      int64_t getOutputChannelPosition() {
+          return 2;
+      }
+
+      int64_t getInputChannelPosition() {
+          return 2;
+      }
+
+      int64_t getInputGroupsPosition() {
+          return 1;
       }
     }];
 }
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
index 4200343ce3e132..c7c5d617f6492c 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
@@ -20,4 +20,16 @@ def IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
     let cppNamespace = "::mlir::utils";
 }
 
+def GroupedConvDim : I32EnumAttr<"GroupedConvDim", "Convolution dim",
+  [
+    I32EnumAttrCase<"n", 0>, // batch
+    I32EnumAttrCase<"g", 1>, // group
+    I32EnumAttrCase<"f", 2>, // feature (output channel)
+    I32EnumAttrCase<"s", 3>, // spatial
+    I32EnumAttrCase<"c", 4> // channel (input channel)
+  ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::utils";
+}
+
 #endif // STRUCTURED_OPS_UTILS
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index c736bd064bded2..b98de5ee259e45 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -679,9 +679,11 @@ ArrayAttr mlir::linalg::detail::grouped_convolution_impl::getIteratorTypes(
 
 ArrayAttr
 mlir::linalg::detail::grouped_convolution_impl::createCommonIndexingMaps(
-    MLIRContext *ctx, int64_t numSpatial, int64_t channelPos,
+    MLIRContext *ctx, int64_t numSpatial,
+    const SmallVector<SmallVector<utils::GroupedConvDim>> &layouts,
     const SmallVectorImpl<int64_t> &strides,
     const SmallVectorImpl<int64_t> &dilations) {
+  assert(layouts.size() == 3 && "expected 3 layouts: image, filter, init");
 
   // Domain: (n, g, f, os, c, ks)
   AffineExpr n = getAffineDimExpr(0, ctx);
@@ -695,26 +697,51 @@ mlir::linalg::detail::grouped_convolution_impl::createCommonIndexingMaps(
       llvm::seq<int64_t>(numSpatial + 4, 2 * (numSpatial + 1) + 2),
       [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
 
-  // Initialze operand accesses in nw order and insert c according to channel
-  // position
-  SmallVector<AffineExpr> inExprs = {n}, outExprs = {n};
-  SmallVector<AffineExpr> gc = {g, c};
-  SmallVector<AffineExpr> gf = {g, f};
-  SmallVector<AffineExpr> gfc = {g, f, c};
+  SmallVector<AffineExpr> inSpatials;
+  inSpatials.reserve(numSpatial);
   for (const auto &[sp, ksp, st, di] : llvm::zip(s, ks, strides, dilations)) {
-    inExprs.push_back(sp * st + ksp * di);
-    outExprs.push_back(sp);
+    inSpatials.push_back(sp * st + ksp * di);
   }
-  SmallVector<AffineExpr> kExprs(ks);
-  inExprs.insert(inExprs.begin() + channelPos - 1, gc.begin(), gc.end());
-  kExprs.insert(channelPos == 0 ? kExprs.begin()
-                                : kExprs.begin() + channelPos - 2,
-                gfc.begin(), gfc.end());
-  outExprs.insert(outExprs.begin() + channelPos - 1, gf.begin(), gf.end());
+
+  auto getExprs = [&](const SmallVector<utils::GroupedConvDim> &layout,
+                      const SmallVector<AffineExpr> &spatials) {
+    SmallVector<AffineExpr> exprs(layout.size());
+    int64_t spatialDim;
+    for (const auto &[i, dim] : llvm::enumerate(layout)) {
+      switch (dim) {
+      case utils::GroupedConvDim::n:
+        exprs[i] = n;
+        break;
+      case utils::GroupedConvDim::g:
+        exprs[i] = g;
+        break;
+      case utils::GroupedConvDim::f:
+        exprs[i] = f;
+        break;
+      case utils::GroupedConvDim::s:
+        exprs[i] = spatials[0];
+        spatialDim = i;
+        break;
+      case utils::GroupedConvDim::c:
+        exprs[i] = c;
+        break;
+      default:
+        assert(false);
+      }
+    }
+    if (spatials.size() > 1)
+      exprs.insert(exprs.begin() + spatialDim + 1, spatials.begin() + 1,
+                   spatials.end());
+    return exprs;
+  };
+  SmallVector<AffineExpr> inExprs = getExprs(layouts[0], inSpatials);
+  SmallVector<AffineExpr> kExprs = getExprs(layouts[1], ks);
+  SmallVector<AffineExpr> outExprs = getExprs(layouts[2], s);
   SmallVector<AffineMap> maps(
-      {AffineMap::get(4 + 2 * numSpatial, 0, inExprs, ctx),
-       AffineMap::get(4 + 2 * numSpatial, 0, kExprs, ctx),
-       AffineMap::get(4 + 2 * numSpatial, 0, outExprs, ctx)});
+      {AffineMap::get(4 + 2 * numSpatial, 0, getExprs(layouts[0], inSpatials),
+                      ctx),
+       AffineMap::get(4 + 2 * numSpatial, 0, getExprs(layouts[1], ks), ctx),
+       AffineMap::get(4 + 2 * numSpatial, 0, getExprs(layouts[2], s), ctx)});
 
   return Builder(ctx).getAffineMapArrayAttr(maps);
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c31b17082b8900..29a3f39c8696c9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1761,7 +1761,7 @@ ArrayAttr GroupedConvNDOp::getIndexingMaps() {
     return cached;
 
   cached = detail::grouped_convolution_impl::createCommonIndexingMaps(
-      getContext(), getSpatialRank(), getChannelPosition(), getStridesVector(),
+      getContext(), getSpatialRank(), getLayoutsEnums(), getStridesVector(),
       getDilationsVector());
 
   (*this)->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 9d3444fe2ce9cb..876fdc9b11dc27 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -232,10 +232,11 @@ func.func public @main(%arg0: tensor<2x3xi1>) -> tensor<6xi64> {
 // CHECK-DAG:       %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<64x2x20x8x8x8xf32>
 // CHECK:           memref.copy %[[ARG2_MEMREF]], %[[INIT_BUFFER]] : memref<64x2x20x8x8x8xf32> to memref<64x2x20x8x8x8xf32>
 // CHECK:           linalg.grouped_conv_nd
-// CHECK-SAME:      {channel_first = true, dilations = dense<2> : tensor<3xi64>, strides = dense<3> : tensor<3xi64>}
+// CHECK-SAME:      dilations = dense<2> : tensor<3xi64>
+// CHECK-SAME:      strides = dense<3> : tensor<3xi64>}
 // CHECK-SAME:      ins(%[[ARG0_MEMREF]], %[[ARG1_MEMREF]] : memref<64x2x16x26x26x26xf32>, memref<2x20x16x3x3x3xf32>)
 // CHECK-SAME:      outs(%[[INIT_BUFFER]] : memref<64x2x20x8x8x8xf32>)
 func.func @gen_grouped_3D_channel_first_tensor(%arg0: tensor<64x2x16x26x26x26xf32>, %arg1: tensor<2x20x16x3x3x3xf32>, %arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32> {
-    %0 = linalg.grouped_conv_nd {channel_first = true, strides = dense<3> : tensor<3xi64>, dilations = dense<2> : tensor<3xi64>} ins(%arg0, %arg1: tensor<64x2x16x26x26x26xf32>, tensor<2x20x16x3x3x3xf32>) outs(%arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32>
+    %0 = linalg.grouped_conv_nd {strides = dense<3> : tensor<3xi64>, dilations = dense<2> : tensor<3xi64>} ins(%arg0, %arg1: tensor<64x2x16x26x26x26xf32>, tensor<2x20x16x3x3x3xf32>) outs(%arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32>
     return %0 : tensor<64x2x20x8x8x8xf32> 
 }
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index bd728edd1ec715..24177a3a8d7fa6 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -4,8 +4,8 @@
 
 // CHECK-LABEL: func @gen_grouped_1D_channel_first_memref
 func.func @gen_grouped_1D_channel_first_memref(%arg0: memref<64x8x16x10xf32>, %arg1: memref<8x32x16x3xf32>, %arg2: memref<64x8x32x8xf32>) {
-  // CHECK: grouped_conv_nd {{.*}}channel_first = true
-    linalg.grouped_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x8x16x10xf32>, memref<8x32x16x3xf32>) outs(%arg2: memref<64x8x32x8xf32>)
+  // CHECK: grouped_conv_nd 
+    linalg.grouped_conv_nd ins(%arg0, %arg1: memref<64x8x16x10xf32>, memref<8x32x16x3xf32>) outs(%arg2: memref<64x8x32x8xf32>)
     return
 }
 
diff --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir
index 50065ccb18cb84..475e2565ec5f94 100644
--- a/mlir/test/Dialect/Linalg/tile-conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile-conv.mlir
@@ -52,7 +52,7 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0)[s0] -> (d0 + s0 - 1)>
 
 func.func @grouped_conv_2D(%arg0 : memref<?x?x?x?x?xf32>, %arg1 : memref<?x?x?x?x?xf32>, %arg2 : memref<?x?x?x?x?xf32>) {
-  linalg.grouped_conv_nd ins(%arg0, %arg1 : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>) outs(%arg2 : memref<?x?x?x?x?xf32>)
+  linalg.grouped_conv_nd {layouts = ["ngcs", "gfcs", "ngfs"]} ins(%arg0, %arg1 : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>) outs(%arg2 : memref<?x?x?x?x?xf32>)
   return
 }
 
@@ -98,6 +98,6 @@ module attributes {transform.with_named_sequence} {
 //   CHECK-DAG:           %[[SVIN:.*]] = memref.subview %[[ARG0]][%[[I]], %[[J]], 0, %[[L]], %[[M]]] [%[[T4]], %[[T5]], %[[IN_CHANNELS]], %[[T9]], %[[T10]]]
 //   CHECK-DAG:           %[[SVKER:.*]] = memref.subview %[[ARG1]][%[[J]], %[[K]], 0, 0, 0] [%[[T5]], %[[T6]], %[[IN_CHANNELS]], %[[KW]], %[[KH]]]
 //   CHECK-DAG:           %[[SVOUT:.*]] = memref.subview %[[ARG2]][%[[I]], %[[J]], %[[K]], %[[L]], %[[M]]] [%[[T4]], %[[T5]], %[[T6]], %[[T7]], %[[T8]]]
-//       CHECK:           linalg.grouped_conv_nd {channel_first = true}
+//       CHECK:           linalg.grouped_conv_nd {layouts = ["ngcs", "gfcs", "ngfs"]}
 //  CHECK-SAME:             ins(%[[SVIN]], %[[SVKER]]
 //  CHECK-SAME:             outs(%[[SVOUT]]
\ No newline at end of file



More information about the llvm-commits mailing list