[Mlir-commits] [mlir] [mlir][linalg] Implement `LinalgGroupedConvolutionOpInterface` to unify grouped convs (PR #94796)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 1 13:36:41 PDT 2024


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/94796

>From b361e4d021dbde4d3c2c1fca6dd36030a7aa4376 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 7 Jun 2024 14:48:18 -0500
Subject: [PATCH 1/3] Implement `LinalgGroupedConvolutionOpInterface` to unify
 grouped convs

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |  37 +++++
 .../Dialect/Linalg/IR/LinalgInterfaces.td     |  95 ++++++++++++
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 141 ++++++++++++++++++
 .../mlir/Dialect/Utils/StructuredOpsUtils.td  |  12 ++
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 129 ++++++++++++++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 104 +++++++++++++
 mlir/test/Dialect/Linalg/bufferize.mlir       |  21 +++
 mlir/test/Dialect/Linalg/loops.mlir           |   8 +-
 mlir/test/Dialect/Linalg/named-ops.mlir       |  11 ++
 mlir/test/Dialect/Linalg/tile-conv.mlir       |  62 +++++++-
 10 files changed, 614 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 08afdf373f014..ed5b4ff2de4dc 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -28,6 +28,8 @@ namespace mlir {
 namespace linalg {
 class IteratorTypeAttr;
 class LinalgOp;
+class ConvolutionOpInterface;
+class GroupedConvolutionOpInterface;
 class GenericOp;
 
 namespace detail {
@@ -133,6 +135,38 @@ std::optional<Value> isaFillOpInterface(GenericOp genericOp);
 
 namespace detail {
 
+// Common implementations for ConvolutionOpInterface
+namespace convolution_impl {
+// Returns strides as a vector.
+SmallVector<int64_t, 2> getStrides(ConvolutionOpInterface op);
+// Returns dilations as a vector.
+SmallVector<int64_t, 2> getDilations(ConvolutionOpInterface op);
+// Region builder for basic convolution
+void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                   ArrayRef<NamedAttribute> attrs);
+// Region builder for basic quantized convolution
+void quantizedRegionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                            ArrayRef<NamedAttribute> attrs);
+void getEffects(
+    Operation *op,
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects);
+ParseResult parse(OpAsmParser &parser, OperationState &result,
+                  bool isQuantized = false);
+void print(LinalgOp op, OpAsmPrinter &p);
+} // namespace convolution_impl
+
+// Common implementations for GroupedConvolutionOpInterface
+namespace grouped_convolution_impl {
+int64_t getSpatialRank(GroupedConvolutionOpInterface op);
+ArrayAttr createCommonIndexingMaps(
+    MLIRContext *ctx, int64_t numSpatial,
+    const SmallVector<SmallVector<utils::GroupedConvDim>> &layouts,
+    const SmallVectorImpl<int64_t> &strides,
+    const SmallVectorImpl<int64_t> &dilations);
+ArrayAttr getIteratorTypes(GroupedConvolutionOpInterface op);
+} // namespace grouped_convolution_impl
+
 /// Returns true if the block contains a contraction of the following form:
 ///
 ///   %0 = <elemwise>(permutation-of(cu(block-argument-0),
@@ -189,6 +223,9 @@ LogicalResult verifyContractionInterface(Operation *op);
 /// Verify that `op` conforms to the ConvolutionOpInterface.
 LogicalResult verifyConvolutionInterface(Operation *op);
 
+/// Verify that `op` conforms to the GroupedConvolutionOpInterface.
+LogicalResult verifyGroupedConvolutionInterface(Operation *op);
+
 /// Verify that `op` conforms to the FillOpInterface.
 LogicalResult verifyFillInterface(Operation *op);
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9..5ae481a222e3c 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -175,6 +175,101 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
         return $_op.getOperation()->getOperand(1);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Return the spatial rank.",
+      /*retTy=*/"int64_t",
+      /*methodName=*/"getSpatialRank",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // Most convolution's inputs have batch, channel and spatial dims
+        return cast<ShapedType>(image().getType()).getRank() - 2;
+      }]
+    >
+  ];
+}
+
+def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInterface", [
+  LinalgConvolutionOpInterface]> {
+  let description = [{
+    A grouped convolution is defined in general terms:
+    1. It is a convolution as defined by `ConvolutionOpInterface`.
+    2. Operands have the following distinct dimensions (excluding batch in input/output): group, channel, spatial
+    3. `input_rank == kernel_rank == output_rank` (including batch in input/output)
+    4. Reductions are along the input channel and spatial dimensions while group, output channel
+       and output spatial dimensions are parallel.
+  }];
+  let cppNamespace = "::mlir::linalg";
+  let verify = [{ return detail::verifyGroupedConvolutionInterface($_op); }];
+  let methods = [
+    InterfaceMethod<[{
+      Returns, for each operand, its layout as a list of `GroupedConvDim` values.
+    }],
+    "SmallVector<SmallVector<::mlir::utils::GroupedConvDim>>", "getLayoutsEnums", (ins)
+    >,
+    InterfaceMethod<[{
+      Returns the groups position for the input.
+    }],
+    "int64_t", "getInputGroupsPosition", (ins)
+    >,
+    InterfaceMethod<[{
+      Returns the channel position for the input.
+    }],
+    "int64_t", "getInputChannelPosition", (ins)
+    >,
+    InterfaceMethod<[{
+      Returns the channel position for the output.
+    }],
+    "int64_t", "getOutputChannelPosition", (ins)
+    >,
+    InterfaceMethod<[{
+      Get number of groups.
+    }],
+    "int64_t", "getNumGroups", (ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+      return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputGroupsPosition()];
+    }]>,
+    InterfaceMethod<[{
+      Get number of input channels. 
+    }],
+    "int64_t", "getNumInputChannels", (ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+      return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputChannelPosition()];
+    }]>,
+    InterfaceMethod<[{
+      Get number of output channels. 
+    }],
+    "int64_t", "getNumOutputChannels", (ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+      return cast<ShapedType>($_op.getDpsInits()[0].getType()).getShape()[$_op.getOutputChannelPosition()];
+    }]>,
+    InterfaceMethod<[{
+      Returns the iterator types of the operation.
+    }],
+    "::mlir::ArrayAttr", "getIteratorTypes", (ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return detail::grouped_convolution_impl::getIteratorTypes($_op);
+      }]>,
+    InterfaceMethod<[{
+      Returns strides.
+    }],
+    "::llvm::SmallVector<int64_t, 2>", "getStridesVector", (ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return detail::convolution_impl::getStrides($_op);
+    }]>,
+    InterfaceMethod<[{
+      Returns dilations.
+    }],
+    "::llvm::SmallVector<int64_t, 2>", "getDilationsVector", (ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return detail::convolution_impl::getDilations($_op);
+    }]>
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index ac61117c3d6e3..7db7c54a4ea09 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -384,6 +384,147 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// GroupedConvNDOp ops.
+//===----------------------------------------------------------------------===//
+
+def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
+  [AttrSizedOperandSegments, LinalgGroupedConvolutionOpInterface]> {
+    
+  let summary = [{
+    Performs N-D grouped convolution with switchable channel position; either first or last.
+  }];
+  let description = [{
+    Allows any number of spatial dimensions but treats all of them as contiguous.  Throughout, `S`,
+    will represent all spatial dimensions.  Operand layouts are determined by the `layouts`
+    `StrArrayAttr` attribute.  Each element of the array is a string representing the layout of the
+    corresponding operand and should be mappable to a `GroupedConvDim` enum, i.e. one of
+      n: (batch dim)
+      g: (group dim)
+      f: (feature or output channel dim)
+      s: (all spatial dims)
+      c: (input channel dim).
+    
+    The domain will always be in the order `(N, G, F, S, C, KS)`.
+
+  }];
+
+    let arguments = (ins
+      Variadic<TensorOrMemref>:$inputs,
+      Variadic<TensorOrMemref>:$inits,
+      DefaultValuedAttr<StrArrayAttr, "{\"ngcs\", \"gfcs\", \"ngfs\"}">:$layouts,
+      OptionalAttr<I64ElementsAttr>:$strides,
+      OptionalAttr<I64ElementsAttr>:$dilations
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "Value":$input, "Value":$filter, "Value":$init,
+            CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        int64_t numSpatialDims = cast<ShapedType>(input.getType()).getRank() - 3;
+        // ArrayRef is non-owning: keep the default in a named local, not a temporary.
+        ::llvm::SmallVector<int64_t, 2> ones(numSpatialDims, 1);
+        if (strides.empty()) strides = ones;
+        if (dilations.empty()) dilations = ones;
+        $_state.addAttribute(getStridesAttrName($_state.name),
+          ::mlir::DenseElementsAttr::get(
+            ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), strides));
+        $_state.addAttribute(getDilationsAttrName($_state.name),
+          ::mlir::DenseElementsAttr::get(
+            ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), dilations));
+        buildStructuredOp($_builder, $_state, std::nullopt, {input, filter}, init,
+          attributes, GroupedConvNDOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
+            "Value":$init,
+            CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        int64_t numSpatialDims = cast<ShapedType>(input.getType()).getRank() - 3;
+        // ArrayRef is non-owning: keep the default in a named local, not a temporary.
+        ::llvm::SmallVector<int64_t, 2> ones(numSpatialDims, 1);
+        if (strides.empty()) strides = ones;
+        if (dilations.empty()) dilations = ones;
+        $_state.addAttribute(getStridesAttrName($_state.name),
+          ::mlir::DenseElementsAttr::get(
+            ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), strides));
+        $_state.addAttribute(getDilationsAttrName($_state.name),
+          ::mlir::DenseElementsAttr::get(
+            ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), dilations));
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+          {input, filter}, init, attributes, GroupedConvNDOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter, 
+      "Value":$init, "Attribute":$strides, "Attribute":$dilations,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute(getStridesAttrName($_state.name), strides);
+        $_state.addAttribute(getDilationsAttrName($_state.name), dilations);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, {input, filter}, init,
+          attributes, GroupedConvNDOp::getRegionBuilder());
+      }]>
+    ];
+
+    // TODO: Figure out how to move this to the interface
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      void print(::mlir::OpAsmPrinter &printer) {
+        return detail::convolution_impl::print(*this, printer);
+      }
+      static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
+                                        ::mlir::OperationState &result) {
+        return detail::convolution_impl::parse(parser, result);
+      }
+      static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+                                mlir::ArrayRef<mlir::NamedAttribute>)>
+      getRegionBuilder() {
+        return detail::convolution_impl::regionBuilder;
+      }
+      // Implement functions necessary for DestinationStyleOpInterface.
+      MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
+
+      // Implement functions necessary for LinalgOp.
+      ArrayAttr getIndexingMaps();
+
+      // Implement functions necessary for GroupedConvolutionOpInterface
+      int64_t getSpatialRank() {
+        return detail::grouped_convolution_impl::getSpatialRank(*this);
+      }
+
+      SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> getLayoutsEnums() {
+        SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> layouts;
+        for (auto attr : (*this).getLayoutsAttr().getValue()) {
+          std::string layoutStr = cast<StringAttr>(attr).getValue().str();
+          SmallVector<::mlir::utils::GroupedConvDim> layout(layoutStr.size());
+          for (size_t i = 0; i < layoutStr.size(); i++) {
+            auto maybeDimEnum = ::mlir::utils::symbolizeGroupedConvDim(layoutStr.substr(i, 1).c_str());
+            assert(maybeDimEnum);
+            layout[i] = maybeDimEnum.value(); 
+          }
+          layouts.push_back(layout);
+        }
+        return layouts;
+      }
+
+      int64_t getOutputChannelPosition() {
+          return 2;
+      }
+
+      int64_t getInputChannelPosition() {
+          return 2;
+      }
+
+      int64_t getInputGroupsPosition() {
+          return 1;
+      }
+    }];
+}
 
 //===----------------------------------------------------------------------===//
 // Transpose op.
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
index 4200343ce3e13..c7c5d617f6492 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
@@ -20,4 +20,16 @@ def IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
     let cppNamespace = "::mlir::utils";
 }
 
+def GroupedConvDim : I32EnumAttr<"GroupedConvDim", "Convolution dim",
+  [
+    I32EnumAttrCase<"n", 0>, // batch
+    I32EnumAttrCase<"g", 1>, // group
+    I32EnumAttrCase<"f", 2>, // feature (output channel)
+    I32EnumAttrCase<"s", 3>, // spatial
+    I32EnumAttrCase<"c", 4> // channel (input channel)
+  ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::utils";
+}
+
 #endif // STRUCTURED_OPS_UTILS
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f35ab3b856b4e..c2db6670e4167 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -766,6 +766,135 @@ enum class MatchConvolutionResult {
 };
 } // namespace mlir::linalg::detail
 
+// Returns the `strides` attribute of `op` as a vector, or one `1` per spatial
+// dim (per `op.getSpatialRank()`) when the attribute is absent.
+SmallVector<int64_t, 2>
+mlir::linalg::detail::convolution_impl::getStrides(ConvolutionOpInterface op) {
+  auto maybeStridesAttr = op->getAttrOfType<DenseIntElementsAttr>("strides");
+  if (!maybeStridesAttr)
+    return SmallVector<int64_t, 2>(op.getSpatialRank(), 1);
+  return llvm::to_vector(maybeStridesAttr.getValues<int64_t>());
+}
+
+// Returns the `dilations` attribute of `op` as a vector, or one `1` per
+// spatial dim (per `op.getSpatialRank()`) when the attribute is absent.
+SmallVector<int64_t, 2> mlir::linalg::detail::convolution_impl::getDilations(
+    ConvolutionOpInterface op) {
+  auto maybeDilationsAttr =
+      op->getAttrOfType<DenseIntElementsAttr>("dilations");
+  if (!maybeDilationsAttr)
+    return SmallVector<int64_t, 2>(op.getSpatialRank(), 1);
+  return llvm::to_vector(maybeDilationsAttr.getValues<int64_t>());
+}
+
+int64_t mlir::linalg::detail::grouped_convolution_impl::getSpatialRank(
+    GroupedConvolutionOpInterface op) {
+  return cast<ShapedType>(op.image().getType()).getRank() - 3;
+}
+
+ArrayAttr mlir::linalg::detail::grouped_convolution_impl::getIteratorTypes(
+    GroupedConvolutionOpInterface op) {
+  int64_t numSpatialDims = op.getSpatialRank();
+  SmallVector<Attribute> iteratorTypes(
+      3 + numSpatialDims, IteratorTypeAttr::get(op.getContext(), par));
+  SmallVector<Attribute> reductions(
+      numSpatialDims + 1, IteratorTypeAttr::get(op.getContext(), red));
+  iteratorTypes.insert(iteratorTypes.end(), reductions.begin(),
+                       reductions.end());
+
+  return Builder(op.getContext()).getArrayAttr(iteratorTypes);
+}
+
+// Builds the (image, filter, init) indexing maps of a grouped convolution.
+ArrayAttr
+mlir::linalg::detail::grouped_convolution_impl::createCommonIndexingMaps(
+    MLIRContext *ctx, int64_t numSpatial,
+    const SmallVector<SmallVector<utils::GroupedConvDim>> &layouts,
+    const SmallVectorImpl<int64_t> &strides,
+    const SmallVectorImpl<int64_t> &dilations) {
+  assert(layouts.size() == 3 && "expected 3 layouts: image, filter, init");
+
+  // Domain: (n, g, f, os, c, ks)
+  AffineExpr n = getAffineDimExpr(0, ctx);
+  AffineExpr g = getAffineDimExpr(1, ctx);
+  AffineExpr f = getAffineDimExpr(2, ctx);
+  SmallVector<AffineExpr> s(
+      llvm::map_range(llvm::seq<int64_t>(3, numSpatial + 3),
+                      [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+  AffineExpr c = getAffineDimExpr(numSpatial + 3, ctx);
+  SmallVector<AffineExpr> ks(llvm::map_range(
+      llvm::seq<int64_t>(numSpatial + 4, 2 * (numSpatial + 1) + 2),
+      [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+
+  // Image spatial index: output index * stride + kernel index * dilation.
+  SmallVector<AffineExpr> inSpatials;
+  inSpatials.reserve(numSpatial);
+  for (const auto &[sp, ksp, st, di] : llvm::zip(s, ks, strides, dilations))
+    inSpatials.push_back(sp * st + ksp * di);
+
+  // Expands one operand layout into result expressions; the single `s` entry
+  // is replaced in place by all of `spatials`.
+  auto getExprs = [&](const SmallVector<utils::GroupedConvDim> &layout,
+                      const SmallVector<AffineExpr> &spatials) {
+    SmallVector<AffineExpr> exprs(layout.size());
+    int64_t spatialDim = -1; // Index of the `s` entry; -1 until one is seen.
+    for (const auto &[i, dim] : llvm::enumerate(layout)) {
+      switch (dim) {
+      case utils::GroupedConvDim::n:
+        exprs[i] = n;
+        break;
+      case utils::GroupedConvDim::g:
+        exprs[i] = g;
+        break;
+      case utils::GroupedConvDim::f:
+        exprs[i] = f;
+        break;
+      case utils::GroupedConvDim::s:
+        exprs[i] = spatials[0];
+        spatialDim = i;
+        break;
+      case utils::GroupedConvDim::c:
+        exprs[i] = c;
+        break;
+      default:
+        assert(false && "unexpected GroupedConvDim");
+      }
+    }
+    // Only splice in the remaining spatial exprs when an `s` entry was found;
+    // this avoids reading an indeterminate insertion position.
+    if (spatialDim >= 0 && spatials.size() > 1)
+      exprs.insert(exprs.begin() + spatialDim + 1, spatials.begin() + 1,
+                   spatials.end());
+    return exprs;
+  };
+  const unsigned numDims = 4 + 2 * numSpatial;
+  SmallVector<AffineMap> maps(
+      {AffineMap::get(numDims, 0, getExprs(layouts[0], inSpatials), ctx),
+       AffineMap::get(numDims, 0, getExprs(layouts[1], ks), ctx),
+       AffineMap::get(numDims, 0, getExprs(layouts[2], s), ctx)});
+
+  return Builder(ctx).getAffineMapArrayAttr(maps);
+}
+
+// Verifies the convolution interface plus equal image/filter/init ranks.
+LogicalResult
+mlir::linalg::detail::verifyGroupedConvolutionInterface(Operation *op) {
+  if (failed(verifyConvolutionInterface(op)))
+    return failure();
+  auto conv = dyn_cast<GroupedConvolutionOpInterface>(op);
+  if (!conv)
+    return failure();
+  // Use `cast` (asserting) throughout: an unchecked `dyn_cast` result would be
+  // dereferenced below if it were ever null.
+  const int64_t imageRank = cast<ShapedType>(conv.image().getType()).getRank();
+  const int64_t kernelRank =
+      cast<ShapedType>(conv.filter().getType()).getRank();
+  const int64_t initRank =
+      cast<ShapedType>(cast<LinalgOp>(op).getDpsInits()[0].getType())
+          .getRank();
+  if (imageRank != kernelRank || imageRank != initRank)
+    return op->emitError(
+        "Rank relationship must be `in_rank == out_rank == kernel_rank`");
+  return success();
+}
+
 mlir::linalg::detail::MatchConvolutionResult
 mlir::linalg::detail::isConvolutionInterfaceImpl(
     Operation *op, ConvolutionDimensions *dimensions) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b79afebfa8158..21cc22f034aa6 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1735,6 +1735,110 @@ LogicalResult ReduceOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ConvolutionOpInterface
+//===----------------------------------------------------------------------===//
+
+// There must be a way to avoid defining the following 3 functions
+ParseResult mlir::linalg::detail::convolution_impl::parse(
+    OpAsmParser &parser, OperationState &result, bool isQuantized) {
+  if (isQuantized)
+    return parseNamedStructuredOp(
+        parser, result, 5,
+        mlir::linalg::detail::convolution_impl::quantizedRegionBuilder);
+  return parseNamedStructuredOp(
+      parser, result, 3, mlir::linalg::detail::convolution_impl::regionBuilder);
+}
+
+void mlir::linalg::detail::convolution_impl::print(LinalgOp op,
+                                                   OpAsmPrinter &p) {
+  printNamedStructuredOp(p, op.getOperation(), op.getDpsInputs(),
+                         op.getDpsInits());
+}
+
+// Builds the {mul, add} region for a convolution: both inputs are cast to the
+// type of the third block argument, multiplied, and added to that argument.
+void mlir::linalg::detail::convolution_impl::regionBuilder(
+    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "ConvolutionInterface regionBuilder expects 3 args");
+  RegionBuilderHelper helper(b, block);
+
+  Value value1 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+                         block.getArgument(0));
+  Value value2 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+                         block.getArgument(1));
+  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+  Value value4 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+  // Yield the single result directly, as quantizedRegionBuilder does.
+  helper.yieldOutputs({value4});
+}
+
+void mlir::linalg::detail::convolution_impl::quantizedRegionBuilder(
+    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 5 &&
+         "ConvolutionInterface regionBuilder expects 5 args");
+  RegionBuilderHelper helper(b, block);
+  Value value1 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+                         block.getArgument(0));
+  Value value2 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+                         block.getArgument(2));
+  Value value3 = helper.buildBinaryFn(BinaryFn::sub, value1, value2);
+  Value value4 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+                         block.getArgument(1));
+  Value value5 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+                         block.getArgument(3));
+  Value value6 = helper.buildBinaryFn(BinaryFn::sub, value4, value5);
+  Value value7 = helper.buildBinaryFn(BinaryFn::mul, value3, value6);
+  Value value8 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(4), value7);
+  helper.yieldOutputs({value8});
+}
+
+void mlir::linalg::detail::convolution_impl::getEffects(
+    Operation *op,
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (!isa<ConvolutionOpInterface>(op))
+    return;
+  if (LinalgOp linalgOp = dyn_cast<LinalgOp>(op)) {
+    if (linalgOp.hasPureTensorSemantics())
+      return;
+    getGenericEffectsImpl(effects, linalgOp);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// GroupedConvNDOp
+//===----------------------------------------------------------------------===//
+
+void GroupedConvNDOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  return detail::convolution_impl::getEffects(*this, effects);
+}
+
+ArrayAttr GroupedConvNDOp::getIndexingMaps() {
+  ArrayAttr cached = (*this)->getAttrOfType<ArrayAttr>(
+      LinalgDialect::kMemoizedIndexingMapsAttrName);
+  if (cached)
+    return cached;
+
+  cached = detail::grouped_convolution_impl::createCommonIndexingMaps(
+      getContext(), getSpatialRank(), getLayoutsEnums(), getStridesVector(),
+      getDilationsVector());
+
+  (*this)->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+  return cached;
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index e8ab1184b1fd2..56fc39f5fc073 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -189,3 +189,24 @@ func.func @bufferize_dot(%in: tensor<4xf32>, %out: tensor<f32>) -> tensor<f32> {
   // CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<f32>
   // CHECK: return %[[OUT_TENSOR]]
 }
+
+// -----
+
+// CHECK-LABEL:   func @gen_grouped_3D_channel_first_tensor(
+// CHECK-SAME:                                   %[[ARG0_TENSOR:.*]]: tensor<64x2x16x26x26x26xf32>,
+// CHECK-SAME:                                   %[[ARG1_TENSOR:.*]]: tensor<2x20x16x3x3x3xf32>,
+// CHECK-SAME:                                   %[[ARG2_TENSOR:.*]]: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32> {
+// CHECK-DAG:       %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0_TENSOR]] : memref<64x2x16x26x26x26xf32>
+// CHECK-DAG:       %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1_TENSOR]] : memref<2x20x16x3x3x3xf32>
+// CHECK-DAG:       %[[ARG2_MEMREF:.*]] = bufferization.to_memref %[[ARG2_TENSOR]] : memref<64x2x20x8x8x8xf32>
+// CHECK-DAG:       %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<64x2x20x8x8x8xf32>
+// CHECK:           memref.copy %[[ARG2_MEMREF]], %[[INIT_BUFFER]] : memref<64x2x20x8x8x8xf32> to memref<64x2x20x8x8x8xf32>
+// CHECK:           linalg.grouped_conv_nd
+// CHECK-SAME:      dilations = dense<2> : tensor<3xi64>
+// CHECK-SAME:      strides = dense<3> : tensor<3xi64>}
+// CHECK-SAME:      ins(%[[ARG0_MEMREF]], %[[ARG1_MEMREF]] : memref<64x2x16x26x26x26xf32>, memref<2x20x16x3x3x3xf32>)
+// CHECK-SAME:      outs(%[[INIT_BUFFER]] : memref<64x2x20x8x8x8xf32>)
+func.func @gen_grouped_3D_channel_first_tensor(%arg0: tensor<64x2x16x26x26x26xf32>, %arg1: tensor<2x20x16x3x3x3xf32>, %arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32> {
+    %0 = linalg.grouped_conv_nd {strides = dense<3> : tensor<3xi64>, dilations = dense<2> : tensor<3xi64>} ins(%arg0, %arg1: tensor<64x2x16x26x26x26xf32>, tensor<2x20x16x3x3x3xf32>) outs(%arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32>
+    return %0 : tensor<64x2x20x8x8x8xf32> 
+}
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index b818170a8e797..df89029e27d86 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1,12 +1,10 @@
-// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s
-// RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefix=CHECKPARALLEL %s
+// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s --check-prefixes=COMMON,CHECK
+// RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefixes=COMMON,CHECKPARALLEL %s
 
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
 // RUN: mlir-opt %s -convert-linalg-to-loops -test-lower-to-llvm -o=/dev/null 2>&1
 
-// CHECK: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
-
-// CHECKPARALLEL: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// COMMON: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
 
 func.func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
   %c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 02ecbed232c8b..a231569672209 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1,5 +1,16 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
+// -----
+
+// CHECK-LABEL: func @gen_grouped_1D_channel_first_memref
+func.func @gen_grouped_1D_channel_first_memref(%arg0: memref<64x8x16x10xf32>, %arg1: memref<8x32x16x3xf32>, %arg2: memref<64x8x32x8xf32>) {
+  // CHECK: grouped_conv_nd 
+    linalg.grouped_conv_nd ins(%arg0, %arg1: memref<64x8x16x10xf32>, memref<8x32x16x3xf32>) outs(%arg2: memref<64x8x32x8xf32>)
+    return
+}
+
+// -----
+
 // CHECK-LABEL: func @depthwise_conv_1d_nwc_wcm
 func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<1x12x8xf32>, %filter: tensor<3x8x8xf32>) -> tensor<1x10x8x8xf32> {
   %zero = arith.constant 0.000000e+00 : f32
diff --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir
index f674996e42f33..6d19665931662 100644
--- a/mlir/test/Dialect/Linalg/tile-conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile-conv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -transform-interpreter -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter -canonicalize -split-input-file | FileCheck %s
 
 //  CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
 //  CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
@@ -41,3 +41,63 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:       linalg.conv_2d
 //  CHECK-SAME:         ins(%[[SVIN]], %[[SVKER]]
 //  CHECK-SAME:         outs(%[[SVOUT]]
+
+// -----
+
+//  CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
+//  CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
+//  CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+//  CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+//  CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 6)>
+//  CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0)[s0] -> (d0 + s0 - 1)>
+
+func.func @grouped_conv_2D(%arg0 : memref<?x?x?x?x?xf32>, %arg1 : memref<?x?x?x?x?xf32>, %arg2 : memref<?x?x?x?x?xf32>) {
+  linalg.grouped_conv_nd {layouts = ["ngcs", "gfcs", "ngfs"]} ins(%arg0, %arg1 : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>) outs(%arg2 : memref<?x?x?x?x?xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.grouped_conv_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop:5 = transform.structured.tile_using_for %0 tile_sizes [2, 3, 4, 5, 6] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+//       CHECK: func @grouped_conv_2D
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?x?x?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?x?x?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?x?x?x?xf32>
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+//   CHECK-DAG:   %[[C6:.*]] = arith.constant 6 : index
+//   CHECK-DAG:   %[[BATCH:.*]] = memref.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[GROUPS:.*]] = memref.dim %[[ARG0]], %[[C1]]
+//   CHECK-DAG:   %[[IN_CHANNELS:.*]] = memref.dim %[[ARG0]], %[[C2]]
+//   CHECK-DAG:   %[[OUT_CHANNELS:.*]] = memref.dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[KW:.*]] = memref.dim %[[ARG1]], %[[C3]]
+//   CHECK-DAG:   %[[KH:.*]] = memref.dim %[[ARG1]], %[[C4]]
+//   CHECK-DAG:   %[[W:.*]] = memref.dim %[[ARG2]], %[[C3]]
+//   CHECK-DAG:   %[[H:.*]] = memref.dim %[[ARG2]], %[[C4]]
+//       CHECK:   scf.for %[[I:.*]] = %[[C0]] to %[[BATCH]] step %[[C2]]
+//       CHECK:     scf.for %[[J:.*]] = %[[C0]] to %[[GROUPS]] step %[[C3]]
+//       CHECK:       scf.for %[[K:.*]] = %[[C0]] to %[[OUT_CHANNELS]] step %[[C4]]
+//       CHECK:         scf.for %[[L:.*]] = %[[C0]] to %[[W]] step %[[C5]]
+//       CHECK:           scf.for %[[M:.*]] = %[[C0]] to %[[H]] step %[[C6]]
+//       CHECK:             %[[T4:.*]] = affine.min #[[MAP0]](%[[I]])[%[[BATCH]]]
+//       CHECK:             %[[T5:.*]] = affine.min #[[MAP1]](%[[J]])[%[[GROUPS]]]
+//   CHECK-DAG:             %[[T6:.*]] = affine.min #[[MAP2]](%[[K]])[%[[OUT_CHANNELS]]]
+//   CHECK-DAG:             %[[T7:.*]] = affine.min #[[MAP3]](%[[L]])[%[[W]]]
+//   CHECK-DAG:             %[[T8:.*]] = affine.min #[[MAP4]](%[[M]])[%[[H]]]
+//   CHECK-DAG:             %[[T9:.*]] = affine.apply #[[MAP5]](%[[T7]])[%[[KW]]]
+//   CHECK-DAG:             %[[T10:.*]] = affine.apply #[[MAP5]](%[[T8]])[%[[KH]]]
+//   CHECK-DAG:             %[[SVIN:.*]] = memref.subview %[[ARG0]][%[[I]], %[[J]], 0, %[[L]], %[[M]]] [%[[T4]], %[[T5]], %[[IN_CHANNELS]], %[[T9]], %[[T10]]]
+//   CHECK-DAG:             %[[SVKER:.*]] = memref.subview %[[ARG1]][%[[J]], %[[K]], 0, 0, 0] [%[[T5]], %[[T6]], %[[IN_CHANNELS]], %[[KW]], %[[KH]]]
+//   CHECK-DAG:             %[[SVOUT:.*]] = memref.subview %[[ARG2]][%[[I]], %[[J]], %[[K]], %[[L]], %[[M]]] [%[[T4]], %[[T5]], %[[T6]], %[[T7]], %[[T8]]]
+//       CHECK:             linalg.grouped_conv_nd {layouts = ["ngcs", "gfcs", "ngfs"]}
+//  CHECK-SAME:               ins(%[[SVIN]], %[[SVKER]]
+//  CHECK-SAME:               outs(%[[SVOUT]]

>From f0678a4fdf0da49c5742c65cf63614f3e41b7e35 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Mon, 1 Jul 2024 13:25:05 -0500
Subject: [PATCH 2/3] small refactor

---
 .../Dialect/Linalg/IR/LinalgInterfaces.td     | 81 ++++++++-----------
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 34 ++++++--
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    |  6 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 14 ----
 4 files changed, 64 insertions(+), 71 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 5ae481a222e3c..e8b61c2cebf3a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -203,9 +203,10 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
   let verify = [{ return detail::verifyGroupedConvolutionInterface($_op); }];
   let methods = [
     InterfaceMethod<[{
-      Returns the groups position for the input.
+      Returns the layouts of each operand (image, kernel, init).  Each layout is represented
+      by a vector of `GroupedConvDim`s.
     }],
-    "SmallVector<SmallVector<::mlir::utils::GroupedConvDim>>", "getLayoutsEnums", (ins)
+    "SmallVector<SmallVector<::mlir::utils::GroupedConvDim>>", "getOperandConvDims", (ins)
     >,
     InterfaceMethod<[{
       Returns the groups position for the input.
@@ -222,55 +223,39 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
     }],
     "int64_t", "getOutputChannelPosition", (ins)
     >,
-    InterfaceMethod<[{
-      Get number of groups. 
-    }],
-    "int64_t", "getNumGroups", (ins),
-      /*methodBody=*/[{}],
-      /*defaultImplementation=*/[{
-      return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputGroupsPosition() - 1];
-    }]>,
-    InterfaceMethod<[{
-      Get number of input channels. 
-    }],
-    "int64_t", "getNumInputChannels", (ins),
-      /*methodBody=*/[{}],
-      /*defaultImplementation=*/[{
-      return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputChannelPosition()];
-    }]>,
-    InterfaceMethod<[{
-      Get number of output channels. 
-    }],
-    "int64_t", "getNumOutputChannels", (ins),
-      /*methodBody=*/[{}],
-      /*defaultImplementation=*/[{
-      return cast<ShapedType>($_op.getDpsInits()[0].getType()).getShape()[$_op.getOutputChannelPosition()];
-    }]>,
-    InterfaceMethod<[{
-      Returns indexing maps for any spatial dimension.
-    }],
-    "::mlir::ArrayAttr", "getIteratorTypes", (ins),
-      /*methodBody=*/[{}],
-      /*defaultImplementation=*/[{
+  ];
+
+  let extraSharedClassDeclaration = [{
+    // Get number of groups.
+    int64_t getNumGroups() {
+      return cast<ShapedType>(
+        cast<::mlir::linalg::ConvolutionOpInterface>(
+          $_op.getOperation()).image().getType())
+            .getShape()[$_op.getInputGroupsPosition()];
+    }
+    // Get number of input channels. 
+    int64_t getNumInputChannels() {
+      return cast<ShapedType>(
+        cast<::mlir::linalg::ConvolutionOpInterface>(
+          $_op.getOperation()).image().getType()).getShape()[$_op.getInputChannelPosition()];
+    }
+    // Get number of output channels. 
+    int64_t getNumOutputChannels() {
+      return cast<ShapedType>($_op->getOperand(2).getType()).getShape()[$_op.getOutputChannelPosition()];
+    }
+    // Returns iterator types.
+    ::mlir::ArrayAttr getIteratorTypes() {
         return detail::grouped_convolution_impl::getIteratorTypes($_op);
-      }]>,
-    InterfaceMethod<[{
-      Returns strides.
-    }],
-    "::llvm::SmallVector<int64_t, 2>", "getStridesVector", (ins),
-      /*methodBody=*/[{}],
-      /*defaultImplementation=*/[{
+    }
+    // Returns strides.
+    ::llvm::SmallVector<int64_t, 2> getStridesVector() {
         return detail::convolution_impl::getStrides($_op);
-    }]>,
-    InterfaceMethod<[{
-      Returns dilations.
-    }],
-    "::llvm::SmallVector<int64_t, 2>", "getDilationsVector", (ins),
-      /*methodBody=*/[{}],
-      /*defaultImplementation=*/[{
+    }
+    // Returns dilations.
+    ::llvm::SmallVector<int64_t, 2> getDilationsVector() {
         return detail::convolution_impl::getDilations($_op);
-    }]>
-  ];
+    }
+  }];
 }
 
 def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 7db7c54a4ea09..ba67fd7e46a65 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -392,7 +392,7 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
   [AttrSizedOperandSegments, LinalgGroupedConvolutionOpInterface]> {
     
   let summary = [{
-    Performs N-D grouped convolution with switchable channel position; either first or last.
+    Performs N-D grouped convolution with parametrizable operand layouts.
   }];
   let description = [{
     Allows any number of spatial dimensions but treats all of them as contiguous.  Throughout, `S`,
@@ -490,14 +490,27 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
       MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
 
       // Implement functions necessary for LinalgOp.
-      ArrayAttr getIndexingMaps();
+      ::mlir::ArrayAttr getIndexingMaps() {
+        ::mlir::ArrayAttr cached = (*this)->getAttrOfType<::mlir::ArrayAttr>(
+            LinalgDialect::kMemoizedIndexingMapsAttrName);
+        if (cached)
+          return cached;
+
+        cached = detail::grouped_convolution_impl::createCommonIndexingMaps(
+            getContext(), getSpatialRank(), getOperandConvDims(), getStridesVector(),
+            getDilationsVector());
+
+        (*this)->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+        return cached;
+      }
+
 
       // Implement functions necessary for GroupedConvolutionOpInterface
       int64_t getSpatialRank() {
         return detail::grouped_convolution_impl::getSpatialRank(*this);
       }
 
-      SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> getLayoutsEnums() {
+      SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> getOperandConvDims() {
         SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> layouts;
         for (auto attr : (*this).getLayoutsAttr().getValue()) {
           std::string layoutStr = cast<StringAttr>(attr).getValue().str();
@@ -513,15 +526,24 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
       }
 
       int64_t getOutputChannelPosition() {
-          return 2;
+        std::string layoutStr = cast<StringAttr>((*this).getLayoutsAttr().getValue()[2]).getValue().str();
+        size_t pos = layoutStr.find("f");
+        assert(pos != ::std::string::npos);
+        return pos;
       }
 
       int64_t getInputChannelPosition() {
-          return 2;
+        std::string layoutStr = cast<StringAttr>((*this).getLayoutsAttr().getValue()[0]).getValue().str();
+        size_t pos = layoutStr.find("c");
+        assert(pos != ::std::string::npos);
+        return pos;
       }
 
       int64_t getInputGroupsPosition() {
-          return 1;
+        std::string layoutStr = cast<StringAttr>((*this).getLayoutsAttr().getValue()[0]).getValue().str();
+        size_t pos = layoutStr.find("g");
+        assert(pos != ::std::string::npos);
+        return pos;
       }
     }];
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index c2db6670e4167..e0020a5497148 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -880,12 +880,12 @@ mlir::linalg::detail::verifyGroupedConvolutionInterface(Operation *op) {
     return failure();
   if (GroupedConvolutionOpInterface conv =
           dyn_cast<GroupedConvolutionOpInterface>(op)) {
-    const auto imageType = conv.image().getType().dyn_cast<ShapedType>();
+    const auto imageType = cast<ShapedType>(conv.image().getType());
     const auto imageRank = imageType.getRank();
     const auto kernelRank =
-        conv.filter().getType().cast<ShapedType>().getRank();
+        cast<ShapedType>(conv.filter().getType()).getRank();
     const auto initType =
-        cast<LinalgOp>(op).getDpsInits()[0].getType().dyn_cast<ShapedType>();
+        cast<ShapedType>(cast<LinalgOp>(op).getDpsInits()[0].getType());
     const auto initRank = initType.getRank();
     if (imageRank != kernelRank || imageRank != initRank)
       return op->emitError(
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 21cc22f034aa6..080692dfd6de0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1825,20 +1825,6 @@ void GroupedConvNDOp::getEffects(
   return detail::convolution_impl::getEffects(*this, effects);
 }
 
-ArrayAttr GroupedConvNDOp::getIndexingMaps() {
-  ArrayAttr cached = (*this)->getAttrOfType<ArrayAttr>(
-      LinalgDialect::kMemoizedIndexingMapsAttrName);
-  if (cached)
-    return cached;
-
-  cached = detail::grouped_convolution_impl::createCommonIndexingMaps(
-      getContext(), getSpatialRank(), getLayoutsEnums(), getStridesVector(),
-      getDilationsVector());
-
-  (*this)->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
-  return cached;
-}
-
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//

>From 30ff1a15cad37169fe4d55c5dbc45e32c7c735c3 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Mon, 1 Jul 2024 15:36:19 -0500
Subject: [PATCH 3/3] add a couple tests to convert to generic

---
 .../Dialect/Linalg/generalize-named-ops.mlir  | 60 +++++++++++++++++++
 1 file changed, 60 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 31fac9b4b4165..ff91e25db6502 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -864,3 +864,63 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec
 
   return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
 }
+
+// -----
+
+// CHECK: #[[A_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3 + d5)>
+// CHECK: #[[B_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// CHECK: #[[C_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK: func @gen_grouped_1D_ngcs_gfcs_ngfs_memref
+func.func @gen_grouped_1D_ngcs_gfcs_ngfs_memref(%arg0: memref<64x8x16x10xf32>, %arg1: memref<8x32x16x3xf32>, %arg2: memref<64x8x32x8xf32>) {
+// CHECK:       linalg.generic
+// CHECK-SAME:    indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]]
+// CHECK-SAME:    iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]
+// CHECK-SAME:    ins(%arg0, %arg1 : {{.*}}) outs(%arg2 : {{.*}})
+// CHECK-NEXT:    ^bb0(%[[IN_0:.*]]: f32, %[[IN_1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:      %[[MUL:.*]] = arith.mulf %[[IN_0]], %[[IN_1]] : f32
+// CHECK-NEXT:      %[[ADD:.*]] = arith.addf %[[OUT]], %[[MUL]] : f32
+// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
+// CHECK-NEXT:    }
+    linalg.grouped_conv_nd {layouts = ["ngcs", "gfcs", "ngfs"]} ins(%arg0, %arg1: memref<64x8x16x10xf32>, memref<8x32x16x3xf32>) outs(%arg2: memref<64x8x32x8xf32>)
+    return
+}
+
+// -----
+
+// CHECK: #[[A_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 * 3 + d6 * 2, d4 * 3 + d7 * 2)>
+// CHECK: #[[B_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
+// CHECK: #[[C_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: func @gen_grouped_2D_ngcs_gfcs_ngfs_memref
+func.func @gen_grouped_2D_ngcs_gfcs_ngfs_memref(%arg0: memref<64x2x16x26x26xf32>, %arg1: memref<2x20x16x3x3xf32>, %arg2: memref<64x2x20x8x8xf32>) {
+// CHECK:       linalg.generic
+// CHECK-SAME:    indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]]
+// CHECK-SAME:    iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
+// CHECK-SAME:    ins(%arg0, %arg1 : {{.*}}) outs(%arg2 : {{.*}})
+// CHECK-NEXT:    ^bb0(%[[IN_0:.*]]: f32, %[[IN_1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:      %[[MUL:.*]] = arith.mulf %[[IN_0]], %[[IN_1]] : f32
+// CHECK-NEXT:      %[[ADD:.*]] = arith.addf %[[OUT]], %[[MUL]] : f32
+// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
+// CHECK-NEXT:    }
+    linalg.grouped_conv_nd {strides = dense<3> : memref<2xi64>, dilations = dense<2> : memref<2xi64>, layouts = ["ngcs", "gfcs", "ngfs"]} ins(%arg0, %arg1: memref<64x2x16x26x26xf32>, memref<2x20x16x3x3xf32>) outs(%arg2: memref<64x2x20x8x8xf32>)
+    return
+}
+
+// -----
+
+// CHECK: #[[A_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d3 * 3 + d6 * 2, d4 * 3 + d7 * 2, d5)>
+// CHECK: #[[B_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d6, d7, d2, d5)>
+// CHECK: #[[C_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d3, d4, d2)>
+// CHECK: func @gen_grouped_2D_ngsc_gsfc_ngsf_memref
+func.func @gen_grouped_2D_ngsc_gsfc_ngsf_memref(%arg0: memref<64x2x26x26x16xf32>, %arg1: memref<2x3x3x20x16xf32>, %arg2: memref<64x2x8x8x20xf32>) {
+// CHECK:       linalg.generic
+// CHECK-SAME:    indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]]
+// CHECK-SAME:    iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
+// CHECK-SAME:    ins(%arg0, %arg1 : {{.*}}) outs(%arg2 : {{.*}})
+// CHECK-NEXT:    ^bb0(%[[IN_0:.*]]: f32, %[[IN_1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:      %[[MUL:.*]] = arith.mulf %[[IN_0]], %[[IN_1]] : f32
+// CHECK-NEXT:      %[[ADD:.*]] = arith.addf %[[OUT]], %[[MUL]] : f32
+// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
+// CHECK-NEXT:    }
+    linalg.grouped_conv_nd {strides = dense<3> : memref<2xi64>, dilations = dense<2> : memref<2xi64>, layouts = ["ngsc", "gsfc", "ngsf"]} ins(%arg0, %arg1: memref<64x2x26x26x16xf32>, memref<2x3x3x20x16xf32>) outs(%arg2: memref<64x2x8x8x20xf32>)
+    return
+}



More information about the Mlir-commits mailing list