[clang-tools-extra] [mlir] [llvm] [clang] Generalize depthwise conv (PR #75017)

via cfe-commits cfe-commits at lists.llvm.org
Sun Dec 10 19:44:12 PST 2023


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/75017

>From f6ca8308f15b1200380b04a6a79cd893ef89a709 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 7 Dec 2023 21:57:20 -0600
Subject: [PATCH 01/22] Rough prototype of depthwise conv with switchable
 channel dim

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  64 ++++++++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 113 ++++++++++++++++++
 2 files changed, 177 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 751edd02288301..716d67fae15886 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -384,6 +384,70 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// DepthwiseConv1DOp op.
+//===----------------------------------------------------------------------===//
+
+def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
+  [AttrSizedOperandSegments, LinalgConvolutionOpInterface]> {
+    
+  let summary = [{
+    Performs 1-D depthwise convolution with switchable channel position; either first or last.
+  }];
+  let description = [{
+
+    Channel position is determined by the `BoolAttr` `channel_first`. If true,
+    layout is
+
+    Input: `NCW`
+    Kernel: `CW`
+
+    otherwise
+
+    Input: `NWC`
+    Kernel: `WC`
+
+
+  }];
+
+    let arguments = (ins
+      Variadic<TensorOrMemref>:$inputs,
+      Variadic<TensorOrMemref>:$inits,
+      DefaultValuedOptionalAttr<BoolAttr, "true">:$channel_first
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let builders = [
+      OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits,
+        "bool":$channel_first,
+        "function_ref<void(OpBuilder &, Location, ValueRange)>",
+        CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasVerifier = 1;
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      // Declare functions necessary for LinalgStructuredInterface.
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+      ArrayAttr getIndexingMaps();
+      std::string getLibraryCallName() {
+        return "op_has_no_registered_library_name";
+      }
+
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+      // Implement functions necessary for DestinationStyleOpInterface.
+      static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+                                mlir::ArrayRef<mlir::NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+      static unsigned getNumRegionArgs() { return 3; }
+      MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
+    }];
+}
 
 //===----------------------------------------------------------------------===//
 // Transpose op.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 58af9995548e93..99ef37c6dc1221 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1672,6 +1672,119 @@ LogicalResult ReduceOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// DepthwiseConv1DOp
+//===----------------------------------------------------------------------===//
+
+void DepthwiseConv1DOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                                      ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "DepthwiseConv1DOp regionBuilder expects 3 (>=0) args");
+  RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
+  // Block arguments are (input element, filter element, accumulator).
+  BlockArgument input = block.getArgument(0);
+  BlockArgument filter = block.getArgument(1);
+  BlockArgument acc = block.getArgument(2);
+  // Both factors are cast (signed) to the accumulator element type first.
+  Type accType = acc.getType();
+  Value castInput = helper.buildTypeFn(TypeFn::cast_signed, accType, input);
+  Value castFilter = helper.buildTypeFn(TypeFn::cast_signed, accType, filter);
+
+  // acc += input * filter
+  Value product = helper.buildBinaryFn(BinaryFn::mul, castInput, castFilter);
+  Value sum = helper.buildBinaryFn(BinaryFn::add, acc, product);
+  helper.yieldOutputs(SmallVector<Value>{sum});
+}
+
+void DepthwiseConv1DOp::build(
+    OpBuilder &builder, OperationState &result, ValueRange inputs,
+    ValueRange inits, bool channel_first,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
+  // Delegate operand/attribute population to the ODS-generated builder.
+  build(builder, result, TypeRange{}, inputs, inits, channel_first);
+  BoolAttr channelFirstAttr = builder.getBoolAttr(channel_first);
+  result.addAttribute(getChannelFirstAttrName(result.name), channelFirstAttr);
+  result.addAttributes(attributes);
+  // Only ranked-tensor inits produce results (with the init's own type);
+  // any other init kind contributes no result.
+  for (Type initType : inits.getTypes())
+    if (llvm::isa<RankedTensorType>(initType))
+      result.addTypes(initType);
+  // Populate the body only when the caller supplied a region builder.
+  if (!bodyBuild)
+    return;
+  buildGenericRegion(builder, result.location, *result.regions.front(),
+                     inputs, inits, bodyBuild);
+}
+
+SmallVector<utils::IteratorType> DepthwiseConv1DOp::getIteratorTypesArray() {
+  // Loop dims are (n, w, c, kw): n, w and c are parallel, while the kernel
+  // window dim kw is reduced into the output.
+  return {utils::IteratorType::parallel, utils::IteratorType::parallel,
+          utils::IteratorType::parallel, utils::IteratorType::reduction};
+}
+
+ArrayAttr DepthwiseConv1DOp::getIndexingMaps() {
+  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(LinalgDialect::kMemoizedIndexingMapsAttrName);
+  if (cached)
+    return cached;
+
+  MLIRContext *context = getContext();
+  SmallVector<AffineExpr> symbolBindings(
+      {getAffineSymbolExpr(0, context), getAffineSymbolExpr(1, context),
+       getAffineSymbolExpr(2, context), getAffineConstantExpr(1, context),
+       getAffineSymbolExpr(4, context), getAffineConstantExpr(1, context)});
+  // Don't actually do something stupid like this
+  SmallVector<StringRef> rawStrings =
+      (getChannelFirst())
+          ? SmallVector<StringRef>({
+                "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
+                "-> (d0, d2, d1 * s3 + d3 * s5)>",
+                "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
+                "-> (d2, d3)>",
+                "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
+                "-> (d0, d2, d1)>",
+            })
+          : SmallVector<StringRef>({
+                "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
+                "-> (d0, d1 * s3 + d3 * s5, d2)>",
+                "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
+                "-> (d3, d2)>",
+                "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
+                "-> (d0, d1, d2)>",
+            });
+  SmallVector<AffineMap> maps(llvm::map_range(rawStrings, [&](StringRef &m) {
+    return simplifyAffineMap(
+        llvm::cast<AffineMapAttr>(parseAttribute(m, context))
+            .getValue()
+            .replaceDimsAndSymbols({}, symbolBindings, 4, 0));
+  }));
+
+  cached = Builder(context).getAffineMapArrayAttr(maps);
+  getOperation()->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+  return cached;
+}
+
+void DepthwiseConv1DOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  // Memory effects are derived generically from results and DPS operands.
+  getGenericEffectsImpl(effects, (*this)->getResults(), getDpsInputs(),
+                        getDpsInits());
+}
+// Custom assembly reuses the shared named-structured-op parser/printer.
+ParseResult DepthwiseConv1DOp::parse(OpAsmParser &parser,
+                                     OperationState &result) {
+  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+                                getRegionBuilder());
+}
+void DepthwiseConv1DOp::print(OpAsmPrinter &p) {
+  printNamedStructuredOp(p, *this, getDpsInputs(), getDpsInits());
+}
+// NOTE(review): no op-specific checks yet (TODO: validate operand ranks).
+LogicalResult DepthwiseConv1DOp::verify() { return success(); }
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//

>From d9097977f7db4c4b484d4e9cd922f7e26deec9e0 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 7 Dec 2023 23:10:51 -0600
Subject: [PATCH 02/22] Support strides and dilations

---
 .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td        | 6 +++++-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp                 | 9 ++++++---
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 716d67fae15886..d8e5b0fcf627a9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -413,7 +413,11 @@ def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
     let arguments = (ins
       Variadic<TensorOrMemref>:$inputs,
       Variadic<TensorOrMemref>:$inits,
-      DefaultValuedOptionalAttr<BoolAttr, "true">:$channel_first
+      DefaultValuedOptionalAttr<BoolAttr, "true">:$channel_first,
+      DefaultValuedOptionalAttr<RankedI64ElementsAttr<[1]>,
+        "{ static_cast<int64_t>(1) }">:$strides,
+      DefaultValuedOptionalAttr<RankedI64ElementsAttr<[1]>,
+        "{ static_cast<int64_t>(1) }">:$dilations
     );
     let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
     let regions = (region AnyRegion:$region);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 99ef37c6dc1221..c56562dc6af778 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1726,15 +1726,18 @@ SmallVector<utils::IteratorType> DepthwiseConv1DOp::getIteratorTypesArray() {
 }
 
 ArrayAttr DepthwiseConv1DOp::getIndexingMaps() {
-  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(LinalgDialect::kMemoizedIndexingMapsAttrName);
+  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(
+      LinalgDialect::kMemoizedIndexingMapsAttrName);
   if (cached)
     return cached;
 
   MLIRContext *context = getContext();
   SmallVector<AffineExpr> symbolBindings(
       {getAffineSymbolExpr(0, context), getAffineSymbolExpr(1, context),
-       getAffineSymbolExpr(2, context), getAffineConstantExpr(1, context),
-       getAffineSymbolExpr(4, context), getAffineConstantExpr(1, context)});
+       getAffineSymbolExpr(2, context),
+       getAffineConstantExpr(getStrides().getValues<int64_t>()[0], context),
+       getAffineSymbolExpr(4, context),
+       getAffineConstantExpr(getDilations().getValues<int64_t>()[0], context)});
   // Don't actually do something stupid like this
   SmallVector<StringRef> rawStrings =
       (getChannelFirst())

>From 6d3a22a24c33f50bf774eea8d1c39cbeb0168c4d Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 8 Dec 2023 01:06:17 -0600
Subject: [PATCH 03/22] Improve map creation

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 19 +++---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 58 ++++++++-----------
 2 files changed, 35 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index d8e5b0fcf627a9..304c621621c4b2 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -395,18 +395,19 @@ def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
     Performs 1-D depthwise convolution with switchable channel position; either first or last.
   }];
   let description = [{
+    Domain: N, OW, C, KW
 
-    Channel position is determined by the `BoolAttr` `channel_first`. If true,
-    layout is
+    Layout of operands is determined by the `channel_first` `BoolAttr`:
 
-    Input: `NCW`
-    Kernel: `CW`
-
-    otherwise
-
-    Input: `NWC`
-    Kernel: `WC`
+    `channel_first == true`:
+      Input: `NCW`
+      Kernel: `CW`
+      Output: `NCW`
 
+    `channel_first == false`:
+      Input: `NWC`
+      Kernel: `WC`
+      Output: `NWC`
 
   }];
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c56562dc6af778..bf1d8d565e64d5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1732,39 +1732,31 @@ ArrayAttr DepthwiseConv1DOp::getIndexingMaps() {
     return cached;
 
   MLIRContext *context = getContext();
-  SmallVector<AffineExpr> symbolBindings(
-      {getAffineSymbolExpr(0, context), getAffineSymbolExpr(1, context),
-       getAffineSymbolExpr(2, context),
-       getAffineConstantExpr(getStrides().getValues<int64_t>()[0], context),
-       getAffineSymbolExpr(4, context),
-       getAffineConstantExpr(getDilations().getValues<int64_t>()[0], context)});
-  // Don't actually do something stupid like this
-  SmallVector<StringRef> rawStrings =
-      (getChannelFirst())
-          ? SmallVector<StringRef>({
-                "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
-                "-> (d0, d2, d1 * s3 + d3 * s5)>",
-                "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
-                "-> (d2, d3)>",
-                "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
-                "-> (d0, d2, d1)>",
-            })
-          : SmallVector<StringRef>({
-                "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
-                "-> (d0, d1 * s3 + d3 * s5, d2)>",
-                "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
-                "-> (d3, d2)>",
-                "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
-                "-> (d0, d1, d2)>",
-            });
-  SmallVector<AffineMap> maps(llvm::map_range(rawStrings, [&](StringRef &m) {
-    return simplifyAffineMap(
-        llvm::cast<AffineMapAttr>(parseAttribute(m, context))
-            .getValue()
-            .replaceDimsAndSymbols({}, symbolBindings, 4, 0));
-  }));
-
-  cached = Builder(context).getAffineMapArrayAttr(maps);
+
+  // Domain: (n, w, c, kw)
+  AffineExpr n = getAffineDimExpr(0, context);
+  AffineExpr w = getAffineDimExpr(1, context);
+  AffineExpr c = getAffineDimExpr(2, context);
+  AffineExpr kw = getAffineDimExpr(3, context);
+
+  // Temp subsitute for channel position attr
+  int64_t channelPos = (getChannelFirst()) ? 1 : 2;
+  // Initialze operand accesses in nw order and insert c according to channel
+  // position
+  SmallVector<AffineExpr> inExprs(
+      {n, w * getStrides().getValues<int64_t>()[0] +
+              kw * getDilations().getValues<int64_t>()[0]});
+  SmallVector<AffineExpr> kExprs({kw});
+  SmallVector<AffineExpr> outExprs({n, w});
+  inExprs.insert(inExprs.begin() + channelPos, c);
+  kExprs.insert(
+      channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, c);
+  outExprs.insert(outExprs.begin() + channelPos, c);
+
+  cached = Builder(context).getAffineMapArrayAttr(
+      {AffineMap::get(4, 0, inExprs, context),
+       AffineMap::get(4, 0, kExprs, context),
+       AffineMap::get(4, 0, outExprs, context)});
   getOperation()->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
   return cached;
 }

>From 13abaefd0dda4c616fdc97225801ed1b5389f6a0 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 8 Dec 2023 01:22:58 -0600
Subject: [PATCH 04/22] Add a couple verification regression tests

---
 mlir/test/Dialect/Linalg/named-ops.mlir | 36 +++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 5ca35155854d33..f281c138727640 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1,5 +1,41 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
+// CHECK-LABEL: func @gen_depthwise_channel_first_memref
+func.func @gen_depthwise_channel_first_memref(%arg0: memref<64x16x10xf32>, %arg1: memref<16x3xf32>, %arg2: memref<64x16x8xf32>) {
+  // CHECK: depthwise_conv_1d {{.*}}channel_first = true
+    linalg.depthwise_conv_1d {channel_first = true} ins(%arg0, %arg1: memref<64x16x10xf32>, memref<16x3xf32>) outs(%arg2: memref<64x16x8xf32>)
+    return
+}
+
+// -----
+
+// CHECK-LABEL: func @gen_depthwise_channel_last_memref
+func.func @gen_depthwise_channel_last_memref(%arg0: memref<64x26x16xf32>, %arg1: memref<3x16xf32>, %arg2: memref<64x8x16xf32>) {
+  // CHECK: depthwise_conv_1d {{.*}}channel_first = false 
+    linalg.depthwise_conv_1d {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<64x26x16xf32>, memref<3x16xf32>) outs(%arg2: memref<64x8x16xf32>)
+    return
+}
+
+// -----
+
+// CHECK-LABEL: func @gen_depthwise_channel_first_tensor
+func.func @gen_depthwise_channel_first_tensor(%arg0: tensor<64x16x10xf32>, %arg1: tensor<16x3xf32>, %arg2: tensor<64x16x8xf32>) -> tensor<64x16x8xf32> {
+  // CHECK: depthwise_conv_1d {{.*}}channel_first = true
+    %0 = linalg.depthwise_conv_1d {channel_first = true} ins(%arg0, %arg1: tensor<64x16x10xf32>, tensor<16x3xf32>) outs(%arg2: tensor<64x16x8xf32>) -> tensor<64x16x8xf32> 
+    return %0 : tensor<64x16x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @gen_depthwise_channel_last_tensor
+func.func @gen_depthwise_channel_last_tensor(%arg0: tensor<64x26x16xf32>, %arg1: tensor<3x16xf32>, %arg2: tensor<64x8x16xf32>) -> tensor<64x8x16xf32> {
+  // CHECK: depthwise_conv_1d {{.*}}channel_first = false 
+    %0 = linalg.depthwise_conv_1d {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: tensor<64x26x16xf32>, tensor<3x16xf32>) outs(%arg2: tensor<64x8x16xf32>) -> tensor<64x8x16xf32>
+    return %0 : tensor<64x8x16xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @depthwise_conv_1d_nwc_wcm
 func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<1x12x8xf32>, %filter: tensor<3x8x8xf32>) -> tensor<1x10x8x8xf32> {
   %zero = arith.constant 0.000000e+00 : f32

>From 840fadb6d12aabf7071488e3d3a3d8cd22e88739 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 8 Dec 2023 02:03:38 -0600
Subject: [PATCH 05/22] Add convert-linalg-to-loops regression tests

---
 mlir/test/Dialect/Linalg/loops.mlir | 69 ++++++++++++++++++++++++++---
 1 file changed, 64 insertions(+), 5 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 8c13422fd63833..4e9bd0d8118a6c 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1,12 +1,11 @@
-// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s
-// RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefix=CHECKPARALLEL %s
+// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s --check-prefixes=COMMON,CHECK
+// RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefixes=COMMON,CHECKPARALLEL %s
 
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
 // RUN: mlir-opt %s -convert-linalg-to-loops -test-lower-to-llvm -o=/dev/null 2>&1
 
-// CHECK: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
-
-// CHECKPARALLEL: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// COMMON: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// COMMON: #[[$stride3Dilation2:.*]] = affine_map<(d0, d1) -> (d0 * 3 + d1 * 2)>
 
 func.func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
   %c0 = arith.constant 0 : index
@@ -843,6 +842,66 @@ func.func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32
 //       CHECKPARALLEL:         %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
 //       CHECKPARALLEL:         store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
 
+
+func.func @gen_depthwise_channel_first_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
+    linalg.depthwise_conv_1d {channel_first = true} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+    return
+}
+
+//    COMMON-LABEL: func @gen_depthwise_channel_first_memref
+//    COMMON-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//    COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//    COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//         COMMON: %[[c0:.*]] = arith.constant 0 : index
+//         COMMON: %[[c1:.*]] = arith.constant 1 : index
+//         COMMON: %[[c2:.*]] = arith.constant 2 : index
+//         COMMON: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?xf32>
+//         COMMON: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c1]] : memref<?x?x?xf32>
+//         COMMON: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c1]] : memref<?x?xf32>
+//         COMMON: %[[dim3:.*]] = memref.dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
+//          CHECK: scf.for %[[n:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
+//          CHECK:   scf.for %[[ow:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
+//          CHECK:     scf.for %[[c:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
+//  CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[ow:.*]], %[[c:.*]]) = (%[[c0]], %[[c0]], %[[c0]])  to (%[[dim0]], %[[dim3]], %[[dim1]])
+//         COMMON:       scf.for %[[kw:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
+//         COMMON:         %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[ow]], %[[kw]])
+//         COMMON:         %[[vb:.*]] = memref.load %[[arg0]][%[[n]], %[[c]], %[[aff]]] : memref<?x?x?xf32>
+//         COMMON:         %[[va:.*]] = memref.load %[[arg1]][%[[c]], %[[kw]]] : memref<?x?xf32>
+//         COMMON:         %[[vc:.*]] = memref.load %[[arg2]][%[[n]], %[[c]], %[[ow]]] : memref<?x?x?xf32>
+//         COMMON:         %[[inc:.*]] = arith.mulf %[[vb]], %[[va]] : f32
+//         COMMON:         %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
+//         COMMON:         store %[[res]], %[[arg2]][%[[n]], %[[c]], %[[ow]]] : memref<?x?x?xf32>
+
+
+func.func @gen_depthwise_channel_last_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
+    linalg.depthwise_conv_1d {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+    return
+}
+
+//    COMMON-LABEL: func @gen_depthwise_channel_last_memref
+//    COMMON-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//    COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//    COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//         COMMON-DAG: %[[c0:.*]] = arith.constant 0 : index
+//         COMMON-DAG: %[[c1:.*]] = arith.constant 1 : index
+//         COMMON-DAG: %[[c2:.*]] = arith.constant 2 : index
+//         COMMON-DAG: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?xf32>
+//         COMMON-DAG: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<?x?x?xf32>
+//         COMMON-DAG: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c0]] : memref<?x?xf32>
+//         COMMON-DAG: %[[dim3:.*]] = memref.dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
+//          CHECK: scf.for %[[n:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
+//          CHECK:   scf.for %[[ow:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
+//          CHECK:     scf.for %[[c:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
+//  CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[ow:.*]], %[[c:.*]]) = (%[[c0]], %[[c0]], %[[c0]])  to (%[[dim0]], %[[dim3]], %[[dim1]])
+//         COMMON:       scf.for %[[kw:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
+//         COMMON:         %[[aff:.*]] = affine.apply #[[$stride3Dilation2]](%[[ow]], %[[kw]])
+//         COMMON:         %[[vb:.*]] = memref.load %[[arg0]][%[[n]], %[[aff]], %[[c]]] : memref<?x?x?xf32>
+//         COMMON:         %[[va:.*]] = memref.load %[[arg1]][%[[kw]], %[[c]]] : memref<?x?xf32>
+//         COMMON:         %[[vc:.*]] = memref.load %[[arg2]][%[[n]], %[[ow]], %[[c]]] : memref<?x?x?xf32>
+//         COMMON:         %[[inc:.*]] = arith.mulf %[[vb]], %[[va]] : f32
+//         COMMON:         %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
+//         COMMON:         store %[[res]], %[[arg2]][%[[n]], %[[ow]], %[[c]]] : memref<?x?x?xf32>
+
 // -----
 
 func.func @lower_to_loops_with_rank_reducing_subviews(

>From 488c1702615fed9a932e3d6f04eee8f286e07cd5 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 8 Dec 2023 15:14:21 -0600
Subject: [PATCH 06/22] define depthwise conv interface (works but messy)

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |   6 +
 .../Dialect/Linalg/IR/LinalgInterfaces.td     |  48 ++++++-
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  11 +-
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    |  60 +++++++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 119 ++++++++++++------
 5 files changed, 203 insertions(+), 41 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 6c8240267e7d05..c3eb3ff665c20d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -28,6 +28,7 @@ namespace mlir {
 namespace linalg {
 class IteratorTypeAttr;
 class LinalgOp;
+class DepthwiseConvolutionOpInterface;
 
 namespace detail {
 /// Implementation of the method that check if given operands
@@ -115,6 +116,11 @@ bool isaCopyOpInterface(LinalgOp linalgOp);
 
 namespace detail {
 
+// DictionaryAttr getStridesDict(DepthwiseConvolutionOpInterface op);
+DenseIntElementsAttr getStridesAttr(DepthwiseConvolutionOpInterface op);
+DenseIntElementsAttr getDilationsAttr(DepthwiseConvolutionOpInterface op);
+BoolAttr getChannelFirstAttr(DepthwiseConvolutionOpInterface op);
+ArrayAttr getIndexingMaps(DepthwiseConvolutionOpInterface op);
 /// Returns true if the block contains a contraction of the following form:
 ///
 ///   %0 = <elemwise>(permutation-of(cu(block-argument-0),
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9b..e151d8c634dd87 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -178,6 +178,8 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
   ];
 }
 
+
+
 def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
   let description = [{
     A fill operation is defined in general terms:
@@ -871,7 +873,51 @@ def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
           return {};
         }]
       >
-  ];
+  ]; 
 }
 
+def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpInterface", [LinalgConvolutionOpInterface]> {
+    let description = [{
+      A depthwise convolution is defined in general terms:
+      1. it is a convolution as defined by `ConvolutionOpInterface`
+      2. `out_channels = K * in_channels` for some integer channel multiplier `K`
+      3. The indexing maps of the input have expressions that satisfy
+      ```
+        AffineExpr ::== AffineDimExpr | ConvolvedExpr
+        ConvolvedExpr ::== MulExpr (`+` MulExpr)+
+        MulExpr ::== AffineDimExpr (`*` (AffineConstantExpr | AffineSymbolExpr))?
+      ```
+    }];
+    let cppNamespace = "::mlir::linalg";
+    let verify = [{ return detail::verifyConvolutionInterface($_op); }];
+    let methods = [
+      InterfaceMethod<[{
+        Returns the strides attribute (one entry per spatial dimension).
+      }],
+      "::mlir::DenseIntElementsAttr", "getStridesAttr", (ins),
+        /*methodBody=*/[{}],
+        /*defaultImplementation=*/[{
+          // Unit strides are synthesized when the attribute is absent.
+          return detail::getStridesAttr($_op);
+        }]>,
+      InterfaceMethod<[{
+        Returns the dilations attribute.
+      }],
+      "::mlir::DenseIntElementsAttr", "getDilationsAttr", (ins),
+        /*methodBody=*/[{}],
+        /*defaultImplementation=*/[{
+          // Delegates to the out-of-line helper in LinalgInterfaces.cpp.
+          return detail::getDilationsAttr($_op);
+        }]>,
+      InterfaceMethod<[{
+        Returns the `channel_first` attribute (true: channel dim precedes spatial dims).
+      }],
+      "::mlir::BoolAttr", "getChannelFirstAttr", (ins),
+        /*methodBody=*/[{}],
+        /*defaultImplementation=*/[{
+          return detail::getChannelFirstAttr($_op);
+        }]>
+    ];
+  }
+
 #endif // LINALG_IR_LINALGINTERFACES
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 304c621621c4b2..37d668bc43f1b2 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -389,7 +389,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
 //===----------------------------------------------------------------------===//
 
 def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
-  [AttrSizedOperandSegments, LinalgConvolutionOpInterface]> {
+  [AttrSizedOperandSegments, LinalgDepthwiseConvolutionOpInterface]> {
     
   let summary = [{
     Performs 1-D depthwise convolution with switchable channel position; either first or last.
@@ -414,10 +414,10 @@ def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
     let arguments = (ins
       Variadic<TensorOrMemref>:$inputs,
       Variadic<TensorOrMemref>:$inits,
-      DefaultValuedOptionalAttr<BoolAttr, "true">:$channel_first,
-      DefaultValuedOptionalAttr<RankedI64ElementsAttr<[1]>,
+      DefaultValuedAttr<BoolAttr, "true">:$channel_first,
+      DefaultValuedAttr<RankedI64ElementsAttr<[1]>,
         "{ static_cast<int64_t>(1) }">:$strides,
-      DefaultValuedOptionalAttr<RankedI64ElementsAttr<[1]>,
+      DefaultValuedAttr<RankedI64ElementsAttr<[1]>,
         "{ static_cast<int64_t>(1) }">:$dilations
     );
     let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
@@ -435,7 +435,8 @@ def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
     let extraClassDeclaration = structuredOpsBaseDecls # [{
       // Declare functions necessary for LinalgStructuredInterface.
       SmallVector<utils::IteratorType> getIteratorTypesArray();
-      ArrayAttr getIndexingMaps();
+      // Figure out how to get implementation in `LinalgDepthwiseConvolutionOpInterface` working 
+      ArrayAttr getIndexingMaps() { return detail::getIndexingMaps(*this); };
       std::string getLibraryCallName() {
         return "op_has_no_registered_library_name";
       }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ba419d32f22a3e..3c8ae952f49bdb 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -638,6 +638,66 @@ enum class MatchConvolutionResult {
 };
 } // namespace mlir::linalg::detail
 
+DenseIntElementsAttr mlir::linalg::detail::getStridesAttr(DepthwiseConvolutionOpInterface op) {
+  auto maybeStridesAttr = op.getStridesAttr();
+  maybeStridesAttr.dump();
+  if (!maybeStridesAttr) {
+    OpBuilder builder(op.getContext());
+    int64_t numSpatialDims = op.image().getType().cast<ShapedType>().getRank() - 2;
+    auto type = RankedTensorType::get({static_cast<int64_t>(numSpatialDims)},
+                                    builder.getI64Type());
+    SmallVector<int64_t> strides(numSpatialDims, 1);
+    return DenseIntElementsAttr::get(type, strides);
+  }
+  return op.getStridesAttr();
+}
+
+DenseIntElementsAttr mlir::linalg::detail::getDilationsAttr(DepthwiseConvolutionOpInterface op) {
+  return op.getDilationsAttr();
+}
+
+BoolAttr mlir::linalg::detail::getChannelFirstAttr(DepthwiseConvolutionOpInterface op) {
+  return op.getChannelFirstAttr();
+}
+
+ArrayAttr mlir::linalg::detail::getIndexingMaps(DepthwiseConvolutionOpInterface op) {
+  ArrayAttr cached = op->getAttrOfType<ArrayAttr>(
+      LinalgDialect::kMemoizedIndexingMapsAttrName);
+  if (cached)
+    return cached;
+
+  MLIRContext *ctx = op.getContext();
+  auto numSpatial = op.image().getType().cast<ShapedType>().getRank() - 2;
+  // Domain: (n, w, c, kw)
+  AffineExpr n = getAffineDimExpr(0, ctx);
+  SmallVector<AffineExpr> s(llvm::map_range(llvm::seq<int64_t>(1, numSpatial + 1), [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+  AffineExpr c = getAffineDimExpr(numSpatial + 1, ctx);
+  SmallVector<AffineExpr> ks(llvm::map_range(llvm::seq<int64_t>(numSpatial + 2, 2 * (numSpatial + 1)), [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+  // Temp subsitute for channel position attr
+  int64_t channelPos = (op.getChannelFirstAttr().getValue()) ? 1 : numSpatial + 1;
+  
+  // Initialze operand accesses in nw order and insert c according to channel
+  // position
+  SmallVector<AffineExpr> inExprs = {n}, outExprs = {n};
+  for (const auto &[sp, ksp, st, di] : llvm::zip(s, ks, op.getStridesAttr().getValues<int64_t>(), op.getDilationsAttr().getValues<int64_t>())) {
+    inExprs.push_back(sp * st + ksp * di);
+    outExprs.push_back(sp);
+  }
+  SmallVector<AffineExpr> kExprs(ks);
+  inExprs.insert(inExprs.begin() + channelPos, c);
+  kExprs.insert(
+      channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, c);
+  outExprs.insert(outExprs.begin() + channelPos, c);
+
+
+  cached = Builder(ctx).getAffineMapArrayAttr(
+      {AffineMap::get(4, 0, inExprs, ctx),
+       AffineMap::get(4, 0, kExprs, ctx),
+       AffineMap::get(4, 0, outExprs, ctx)});
+  op->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+  return cached;
+}
+
 mlir::linalg::detail::MatchConvolutionResult
 mlir::linalg::detail::isConvolutionInterfaceImpl(
     Operation *op, ConvolutionDimensions *dimensions) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index bf1d8d565e64d5..6e904aa03c1eaf 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1725,41 +1725,90 @@ SmallVector<utils::IteratorType> DepthwiseConv1DOp::getIteratorTypesArray() {
   ;
 }
 
-ArrayAttr DepthwiseConv1DOp::getIndexingMaps() {
-  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(
-      LinalgDialect::kMemoizedIndexingMapsAttrName);
-  if (cached)
-    return cached;
-
-  MLIRContext *context = getContext();
-
-  // Domain: (n, w, c, kw)
-  AffineExpr n = getAffineDimExpr(0, context);
-  AffineExpr w = getAffineDimExpr(1, context);
-  AffineExpr c = getAffineDimExpr(2, context);
-  AffineExpr kw = getAffineDimExpr(3, context);
-
-  // Temp subsitute for channel position attr
-  int64_t channelPos = (getChannelFirst()) ? 1 : 2;
-  // Initialze operand accesses in nw order and insert c according to channel
-  // position
-  SmallVector<AffineExpr> inExprs(
-      {n, w * getStrides().getValues<int64_t>()[0] +
-              kw * getDilations().getValues<int64_t>()[0]});
-  SmallVector<AffineExpr> kExprs({kw});
-  SmallVector<AffineExpr> outExprs({n, w});
-  inExprs.insert(inExprs.begin() + channelPos, c);
-  kExprs.insert(
-      channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, c);
-  outExprs.insert(outExprs.begin() + channelPos, c);
-
-  cached = Builder(context).getAffineMapArrayAttr(
-      {AffineMap::get(4, 0, inExprs, context),
-       AffineMap::get(4, 0, kExprs, context),
-       AffineMap::get(4, 0, outExprs, context)});
-  getOperation()->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
-  return cached;
-}
+// ArrayAttr getNDIndexngMaps(DepthwiseConvolutionOpInterface op) {
+//   ArrayAttr cached = op->getAttrOfType<ArrayAttr>(
+//       LinalgDialect::kMemoizedIndexingMapsAttrName);
+//   if (cached)
+//     return cached;
+
+//   MLIRContext *ctx = op.getContext();
+//   auto numSpatial = op.image().getType().cast<ShapedType>().getRank() - 2;
+//   // Domain: (n, w, c, kw)
+//   AffineExpr n = getAffineDimExpr(0, ctx);
+//   SmallVector<AffineExpr> s(llvm::map_range(llvm::seq<int64_t>(1, numSpatial), [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+//   AffineExpr c = getAffineDimExpr(numSpatial + 1, ctx);
+//   SmallVector<AffineExpr> ks(llvm::map_range(llvm::seq<int64_t>(numSpatial + 2, 2 * (numSpatial + 1)), [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+
+//   // Temp subsitute for channel position attr
+//   int64_t channelPos = (1) ? 1 : numSpatial + 1;
+  
+//   // Initialze operand accesses in nw order and insert c according to channel
+//   // position
+//   SmallVector<AffineExpr> inExprs, outExprs = {n};
+//   for (const auto &[sp, ksp, st, di] : llvm::zip(s, ks, op.getStridesAttr().getValues<int64_t>(), op.getDilationsAttr().getValues<int64_t>())) {
+//     inExprs.push_back(sp * st + ksp * di);
+//     outExprs.push_back(sp);
+//   }
+//   SmallVector<AffineExpr> kExprs(ks);
+//   inExprs.insert(inExprs.begin() + channelPos, c);
+//   kExprs.insert(
+//       channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, c);
+//   outExprs.insert(outExprs.begin() + channelPos, c);
+
+//   n.dump();
+//   for (auto sp : s)
+//     sp.dump();
+//   c.dump();
+//   for (auto ksp : ks)
+//     ksp.dump();
+
+//   for (auto b : inExprs)
+//     b.dump();
+
+//   cached = Builder(ctx).getAffineMapArrayAttr(
+//       {AffineMap::get(4, 0, inExprs, ctx),
+//        AffineMap::get(4, 0, kExprs, ctx),
+//        AffineMap::get(4, 0, outExprs, ctx)});
+//   op->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+//   return cached;
+// }
+
+// ArrayAttr DepthwiseConv1DOp::getIndexingMaps() {
+//   return getNDIndexngMaps(this);
+//   // ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(
+//   //     LinalgDialect::kMemoizedIndexingMapsAttrName);
+//   // if (cached)
+//   //   return cached;
+
+//   // MLIRContext *context = getContext();
+
+//   // // Domain: (n, w, c, kw)
+//   // AffineExpr n = getAffineDimExpr(0, context);
+//   // AffineExpr w = getAffineDimExpr(1, context);
+//   // AffineExpr c = getAffineDimExpr(2, context);
+//   // AffineExpr kw = getAffineDimExpr(3, context);
+
+//   // // Temp subsitute for channel position attr
+//   // int64_t channelPos = (getChannelFirst()) ? 1 : 2;
+//   // // Initialze operand accesses in nw order and insert c according to channel
+//   // // position
+//   // SmallVector<AffineExpr> inExprs(
+//   //     {n, w * getStrides().getValues<int64_t>()[0] +
+//   //             kw * getDilations().getValues<int64_t>()[0]});
+//   // SmallVector<AffineExpr> kExprs({kw});
+//   // SmallVector<AffineExpr> outExprs({n, w});
+//   // inExprs.insert(inExprs.begin() + channelPos, c);
+//   // kExprs.insert(
+//   //     channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, c);
+//   // outExprs.insert(outExprs.begin() + channelPos, c);
+
+//   // cached = Builder(context).getAffineMapArrayAttr(
+//   //     {AffineMap::get(4, 0, inExprs, context),
+//   //      AffineMap::get(4, 0, kExprs, context),
+//   //      AffineMap::get(4, 0, outExprs, context)});
+//   // getOperation()->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+//   // return cached;
+// }
 
 void DepthwiseConv1DOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>

>From e22dc06986c9bd7d72440b61ba398f49697abc0e Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 8 Dec 2023 17:14:31 -0600
Subject: [PATCH 07/22] Add 2d depthwise with tests

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  69 +++++++-
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    |  41 +++--
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 159 ++++++++----------
 mlir/test/Dialect/Linalg/loops.mlir           |  35 ++++
 mlir/test/Dialect/Linalg/named-ops.mlir       |   9 +
 5 files changed, 212 insertions(+), 101 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 37d668bc43f1b2..866f1f1fde0d51 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -385,7 +385,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
 }
 
 //===----------------------------------------------------------------------===//
-// DepthwiseConv1DOp op.
+// DepthwiseConvNDOp ops.
 //===----------------------------------------------------------------------===//
 
 def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
@@ -455,6 +455,73 @@ def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
     }];
 }
 
+def DepthwiseConv2DOp : LinalgStructuredBase_Op<"depthwise_conv_2d",
+  [AttrSizedOperandSegments, LinalgDepthwiseConvolutionOpInterface]> {
+    
+  let summary = [{
+    Performs 2-D depthwise convolution with switchable channel position; either first or last.
+  }];
+  let description = [{
+    Domain: N, OH, OW, C, KH, KW
+
+    Layout of operands is determined by the `channel_first` `BoolAttr`:
+
+    `channel_first == true`:
+      Input: `NCHW`
+      Kernel: `CHW`
+      Output: `NCHW`
+
+    `channel_first == false`:
+      Input: `NHWC`
+      Kernel: `HWC`
+      Output: `NHWC`
+
+  }];
+
+    let arguments = (ins
+      Variadic<TensorOrMemref>:$inputs,
+      Variadic<TensorOrMemref>:$inits,
+      DefaultValuedAttr<BoolAttr, "true">:$channel_first,
+      DefaultValuedAttr<RankedI64ElementsAttr<[2]>,
+        "{ static_cast<int64_t>(1) }">:$strides,
+      DefaultValuedAttr<RankedI64ElementsAttr<[2]>,
+        "{ static_cast<int64_t>(1) }">:$dilations
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let builders = [
+      OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits,
+        "bool":$channel_first,
+        "function_ref<void(OpBuilder &, Location, ValueRange)>",
+        CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasVerifier = 1;
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      // Declare functions necessary for LinalgStructuredInterface.
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+      // Figure out how to get implementation in `LinalgDepthwiseConvolutionOpInterface` working 
+      ArrayAttr getIndexingMaps() { return detail::getIndexingMaps(*this); };
+      std::string getLibraryCallName() {
+        return "op_has_no_registered_library_name";
+      }
+
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+      // Implement functions necessary for DestinationStyleOpInterface.
+      static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+                                mlir::ArrayRef<mlir::NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+      static unsigned getNumRegionArgs() { return 3; }
+      MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
+    }];
+}
+
 //===----------------------------------------------------------------------===//
 // Transpose op.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 3c8ae952f49bdb..cdd830c9d90a68 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -638,29 +638,34 @@ enum class MatchConvolutionResult {
 };
 } // namespace mlir::linalg::detail
 
-DenseIntElementsAttr mlir::linalg::detail::getStridesAttr(DepthwiseConvolutionOpInterface op) {
+DenseIntElementsAttr
+mlir::linalg::detail::getStridesAttr(DepthwiseConvolutionOpInterface op) {
   auto maybeStridesAttr = op.getStridesAttr();
   maybeStridesAttr.dump();
   if (!maybeStridesAttr) {
     OpBuilder builder(op.getContext());
-    int64_t numSpatialDims = op.image().getType().cast<ShapedType>().getRank() - 2;
+    int64_t numSpatialDims =
+        op.image().getType().cast<ShapedType>().getRank() - 2;
     auto type = RankedTensorType::get({static_cast<int64_t>(numSpatialDims)},
-                                    builder.getI64Type());
+                                      builder.getI64Type());
     SmallVector<int64_t> strides(numSpatialDims, 1);
     return DenseIntElementsAttr::get(type, strides);
   }
   return op.getStridesAttr();
 }
 
-DenseIntElementsAttr mlir::linalg::detail::getDilationsAttr(DepthwiseConvolutionOpInterface op) {
+DenseIntElementsAttr
+mlir::linalg::detail::getDilationsAttr(DepthwiseConvolutionOpInterface op) {
   return op.getDilationsAttr();
 }
 
-BoolAttr mlir::linalg::detail::getChannelFirstAttr(DepthwiseConvolutionOpInterface op) {
+BoolAttr
+mlir::linalg::detail::getChannelFirstAttr(DepthwiseConvolutionOpInterface op) {
   return op.getChannelFirstAttr();
 }
 
-ArrayAttr mlir::linalg::detail::getIndexingMaps(DepthwiseConvolutionOpInterface op) {
+ArrayAttr
+mlir::linalg::detail::getIndexingMaps(DepthwiseConvolutionOpInterface op) {
   ArrayAttr cached = op->getAttrOfType<ArrayAttr>(
       LinalgDialect::kMemoizedIndexingMapsAttrName);
   if (cached)
@@ -670,16 +675,23 @@ ArrayAttr mlir::linalg::detail::getIndexingMaps(DepthwiseConvolutionOpInterface
   auto numSpatial = op.image().getType().cast<ShapedType>().getRank() - 2;
   // Domain: (n, w, c, kw)
   AffineExpr n = getAffineDimExpr(0, ctx);
-  SmallVector<AffineExpr> s(llvm::map_range(llvm::seq<int64_t>(1, numSpatial + 1), [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+  SmallVector<AffineExpr> s(
+      llvm::map_range(llvm::seq<int64_t>(1, numSpatial + 1),
+                      [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
   AffineExpr c = getAffineDimExpr(numSpatial + 1, ctx);
-  SmallVector<AffineExpr> ks(llvm::map_range(llvm::seq<int64_t>(numSpatial + 2, 2 * (numSpatial + 1)), [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+  SmallVector<AffineExpr> ks(
+      llvm::map_range(llvm::seq<int64_t>(numSpatial + 2, 2 * (numSpatial + 1)),
+                      [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
   // Temp subsitute for channel position attr
-  int64_t channelPos = (op.getChannelFirstAttr().getValue()) ? 1 : numSpatial + 1;
-  
+  int64_t channelPos =
+      (op.getChannelFirstAttr().getValue()) ? 1 : numSpatial + 1;
+
   // Initialze operand accesses in nw order and insert c according to channel
   // position
   SmallVector<AffineExpr> inExprs = {n}, outExprs = {n};
-  for (const auto &[sp, ksp, st, di] : llvm::zip(s, ks, op.getStridesAttr().getValues<int64_t>(), op.getDilationsAttr().getValues<int64_t>())) {
+  for (const auto &[sp, ksp, st, di] :
+       llvm::zip(s, ks, op.getStridesAttr().getValues<int64_t>(),
+                 op.getDilationsAttr().getValues<int64_t>())) {
     inExprs.push_back(sp * st + ksp * di);
     outExprs.push_back(sp);
   }
@@ -689,11 +701,10 @@ ArrayAttr mlir::linalg::detail::getIndexingMaps(DepthwiseConvolutionOpInterface
       channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, c);
   outExprs.insert(outExprs.begin() + channelPos, c);
 
-
   cached = Builder(ctx).getAffineMapArrayAttr(
-      {AffineMap::get(4, 0, inExprs, ctx),
-       AffineMap::get(4, 0, kExprs, ctx),
-       AffineMap::get(4, 0, outExprs, ctx)});
+      {AffineMap::get(2 + 2 * numSpatial, 0, inExprs, ctx),
+       AffineMap::get(2 + 2 * numSpatial, 0, kExprs, ctx),
+       AffineMap::get(2 + 2 * numSpatial, 0, outExprs, ctx)});
   op->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
   return cached;
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6e904aa03c1eaf..2528a6d1971014 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1725,91 +1725,6 @@ SmallVector<utils::IteratorType> DepthwiseConv1DOp::getIteratorTypesArray() {
   ;
 }
 
-// ArrayAttr getNDIndexngMaps(DepthwiseConvolutionOpInterface op) {
-//   ArrayAttr cached = op->getAttrOfType<ArrayAttr>(
-//       LinalgDialect::kMemoizedIndexingMapsAttrName);
-//   if (cached)
-//     return cached;
-
-//   MLIRContext *ctx = op.getContext();
-//   auto numSpatial = op.image().getType().cast<ShapedType>().getRank() - 2;
-//   // Domain: (n, w, c, kw)
-//   AffineExpr n = getAffineDimExpr(0, ctx);
-//   SmallVector<AffineExpr> s(llvm::map_range(llvm::seq<int64_t>(1, numSpatial), [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
-//   AffineExpr c = getAffineDimExpr(numSpatial + 1, ctx);
-//   SmallVector<AffineExpr> ks(llvm::map_range(llvm::seq<int64_t>(numSpatial + 2, 2 * (numSpatial + 1)), [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
-
-//   // Temp subsitute for channel position attr
-//   int64_t channelPos = (1) ? 1 : numSpatial + 1;
-  
-//   // Initialze operand accesses in nw order and insert c according to channel
-//   // position
-//   SmallVector<AffineExpr> inExprs, outExprs = {n};
-//   for (const auto &[sp, ksp, st, di] : llvm::zip(s, ks, op.getStridesAttr().getValues<int64_t>(), op.getDilationsAttr().getValues<int64_t>())) {
-//     inExprs.push_back(sp * st + ksp * di);
-//     outExprs.push_back(sp);
-//   }
-//   SmallVector<AffineExpr> kExprs(ks);
-//   inExprs.insert(inExprs.begin() + channelPos, c);
-//   kExprs.insert(
-//       channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, c);
-//   outExprs.insert(outExprs.begin() + channelPos, c);
-
-//   n.dump();
-//   for (auto sp : s)
-//     sp.dump();
-//   c.dump();
-//   for (auto ksp : ks)
-//     ksp.dump();
-
-//   for (auto b : inExprs)
-//     b.dump();
-
-//   cached = Builder(ctx).getAffineMapArrayAttr(
-//       {AffineMap::get(4, 0, inExprs, ctx),
-//        AffineMap::get(4, 0, kExprs, ctx),
-//        AffineMap::get(4, 0, outExprs, ctx)});
-//   op->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
-//   return cached;
-// }
-
-// ArrayAttr DepthwiseConv1DOp::getIndexingMaps() {
-//   return getNDIndexngMaps(this);
-//   // ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(
-//   //     LinalgDialect::kMemoizedIndexingMapsAttrName);
-//   // if (cached)
-//   //   return cached;
-
-//   // MLIRContext *context = getContext();
-
-//   // // Domain: (n, w, c, kw)
-//   // AffineExpr n = getAffineDimExpr(0, context);
-//   // AffineExpr w = getAffineDimExpr(1, context);
-//   // AffineExpr c = getAffineDimExpr(2, context);
-//   // AffineExpr kw = getAffineDimExpr(3, context);
-
-//   // // Temp subsitute for channel position attr
-//   // int64_t channelPos = (getChannelFirst()) ? 1 : 2;
-//   // // Initialze operand accesses in nw order and insert c according to channel
-//   // // position
-//   // SmallVector<AffineExpr> inExprs(
-//   //     {n, w * getStrides().getValues<int64_t>()[0] +
-//   //             kw * getDilations().getValues<int64_t>()[0]});
-//   // SmallVector<AffineExpr> kExprs({kw});
-//   // SmallVector<AffineExpr> outExprs({n, w});
-//   // inExprs.insert(inExprs.begin() + channelPos, c);
-//   // kExprs.insert(
-//   //     channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, c);
-//   // outExprs.insert(outExprs.begin() + channelPos, c);
-
-//   // cached = Builder(context).getAffineMapArrayAttr(
-//   //     {AffineMap::get(4, 0, inExprs, context),
-//   //      AffineMap::get(4, 0, kExprs, context),
-//   //      AffineMap::get(4, 0, outExprs, context)});
-//   // getOperation()->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
-//   // return cached;
-// }
-
 void DepthwiseConv1DOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
@@ -1829,6 +1744,80 @@ void DepthwiseConv1DOp::print(OpAsmPrinter &p) {
 
 LogicalResult DepthwiseConv1DOp::verify() { return success(); }
 
+//===----------------------------------------------------------------------===//
+// DepthwiseConv2DOp
+//===----------------------------------------------------------------------===//
+
+// TODO: refactor into base implementation for all spatial dims
+void DepthwiseConv2DOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                                      ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "DepthwiseConv2DOp regionBuilder expects 3 (>=0) args");
+  RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
+  SmallVector<Value> yields;
+
+  Value value1 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+                         block.getArgument(0));
+  Value value2 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+                         block.getArgument(1));
+  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+  Value value4 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+  yields.push_back(value4);
+  helper.yieldOutputs(yields);
+}
+
+void DepthwiseConv2DOp::build(
+    OpBuilder &builder, OperationState &result, ValueRange inputs,
+    ValueRange inits, bool channel_first,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
+  build(builder, result, TypeRange{}, inputs, inits, channel_first);
+  result.addAttribute(getChannelFirstAttrName(result.name),
+                      builder.getBoolAttr(channel_first));
+  result.addAttributes(attributes);
+
+  // Add output types for `RankedTensorType` output arguments.
+  for (Value init : inits) {
+    Type initType = init.getType();
+    if (llvm::isa<RankedTensorType>(initType))
+      result.addTypes(initType);
+  }
+
+  if (bodyBuild)
+    buildGenericRegion(builder, result.location, *result.regions.front(),
+                       inputs, inits, bodyBuild);
+}
+
+SmallVector<utils::IteratorType> DepthwiseConv2DOp::getIteratorTypesArray() {
+  return SmallVector<utils::IteratorType>{
+      utils::IteratorType::parallel,  utils::IteratorType::parallel,
+      utils::IteratorType::parallel,  utils::IteratorType::parallel,
+      utils::IteratorType::reduction, utils::IteratorType::reduction};
+  ;
+}
+
+void DepthwiseConv2DOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+                        getDpsInits());
+}
+
+ParseResult DepthwiseConv2DOp::parse(OpAsmParser &parser,
+                                     OperationState &result) {
+  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+                                getRegionBuilder());
+}
+
+void DepthwiseConv2DOp::print(OpAsmPrinter &p) {
+  printNamedStructuredOp(p, getOperation(), getDpsInputs(), getDpsInits());
+}
+
+LogicalResult DepthwiseConv2DOp::verify() { return success(); }
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 4e9bd0d8118a6c..d23036e07f2223 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -902,6 +902,41 @@ func.func @gen_depthwise_channel_last_memref(%arg0: memref<?x?x?xf32>, %arg1: me
 //         COMMON:         %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
 //         COMMON:         store %[[res]], %[[arg2]][%[[n]], %[[ow]], %[[c]]] : memref<?x?x?xf32>
 
+func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
+    linalg.depthwise_conv_2d {channel_first = true} ins(%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?x?xf32>)
+    return
+}
+
+//    COMMON-LABEL: func @gen_depthwise_2D_channel_first_memref
+//    COMMON-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+//    COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//    COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+//         COMMON-DAG: %[[c0:.*]] = arith.constant 0 : index
+//         COMMON-DAG: %[[c1:.*]] = arith.constant 1 : index
+//         COMMON-DAG: %[[c2:.*]] = arith.constant 2 : index
+//         COMMON-DAG: %[[c3:.*]] = arith.constant 3 : index
+//         COMMON-DAG: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?x?xf32>
+//         COMMON-DAG: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c1]] : memref<?x?x?x?xf32>
+//         COMMON-DAG: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
+//         COMMON-DAG: %[[dim3:.*]] = memref.dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
+//         COMMON-DAG: %[[dim4:.*]] = memref.dim %[[arg2]], %[[c2]] : memref<?x?x?x?xf32>
+//         COMMON-DAG: %[[dim5:.*]] = memref.dim %[[arg2]], %[[c3]] : memref<?x?x?x?xf32>
+//          CHECK: scf.for %[[n:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
+//          CHECK:     scf.for %[[oh:.*]] = %[[c0]] to %[[dim4]] step %[[c1]] {
+//          CHECK:      scf.for %[[ow:.*]] = %[[c0]] to %[[dim5]] step %[[c1]] {
+//          CHECK:       scf.for %[[c:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
+//  CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[oh:.*]], %[[ow:.*]], %[[c:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]])  to (%[[dim0]], %[[dim4]], %[[dim5]], %[[dim1]])
+//         COMMON:       scf.for %[[kh:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
+//         COMMON:         scf.for %[[kw:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
+//         COMMON:           %[[affh:.*]] = affine.apply #[[$stride1Dilation1]](%[[oh]], %[[kh]])
+//         COMMON:           %[[affw:.*]] = affine.apply #[[$stride1Dilation1]](%[[ow]], %[[kw]])
+//         COMMON:           %[[vb:.*]] = memref.load %[[arg0]][%[[n]], %[[c]], %[[affh]], %[[affw]]] : memref<?x?x?x?xf32>
+//         COMMON:           %[[va:.*]] = memref.load %[[arg1]][%[[c]], %[[kh]], %[[kw]]] : memref<?x?x?xf32>
+//         COMMON:           %[[vc:.*]] = memref.load %[[arg2]][%[[n]], %[[c]], %[[oh]], %[[ow]]] : memref<?x?x?x?xf32>
+//         COMMON:           %[[inc:.*]] = arith.mulf %[[vb]], %[[va]] : f32
+//         COMMON:           %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
+//         COMMON:           store %[[res]], %[[arg2]][%[[n]], %[[c]], %[[oh]], %[[ow]]] : memref<?x?x?x?xf32>
+
 // -----
 
 func.func @lower_to_loops_with_rank_reducing_subviews(
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index f281c138727640..54075dbe36ab47 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -9,6 +9,15 @@ func.func @gen_depthwise_channel_first_memref(%arg0: memref<64x16x10xf32>, %arg1
 
 // -----
 
+// CHECK-LABEL: func @gen_depthwise_2D_channel_first_memref
+func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<64x16x10x10xf32>, %arg1: memref<16x3x3xf32>, %arg2: memref<64x16x8x8xf32>) {
+  // CHECK: depthwise_conv_2d {{.*}}channel_first = true
+    linalg.depthwise_conv_2d {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10xf32>, memref<16x3x3xf32>) outs(%arg2: memref<64x16x8x8xf32>)
+    return
+}
+
+// -----
+
 // CHECK-LABEL: func @gen_depthwise_channel_last_memref
 func.func @gen_depthwise_channel_last_memref(%arg0: memref<64x26x16xf32>, %arg1: memref<3x16xf32>, %arg2: memref<64x8x16xf32>) {
   // CHECK: depthwise_conv_1d {{.*}}channel_first = false 

>From 142c2e35f1e7980851289c80f48a815a27838177 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 8 Dec 2023 21:47:21 -0600
Subject: [PATCH 08/22] Refactor common methods to detail functions

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |   5 +-
 .../Dialect/Linalg/IR/LinalgInterfaces.td     | 101 +++++----
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  34 +--
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 197 ++++++++++++++++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  61 +-----
 mlir/test/Dialect/Linalg/loops.mlir           |  36 ++--
 6 files changed, 283 insertions(+), 151 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index c3eb3ff665c20d..a278d6509c19b6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -116,11 +116,14 @@ bool isaCopyOpInterface(LinalgOp linalgOp);
 
 namespace detail {
 
-// DictionaryAttr getStridesDict(DepthwiseConvolutionOpInterface op);
+// Common implementations for DepthwiseConvolutionOpInterface 
 DenseIntElementsAttr getStridesAttr(DepthwiseConvolutionOpInterface op);
 DenseIntElementsAttr getDilationsAttr(DepthwiseConvolutionOpInterface op);
 BoolAttr getChannelFirstAttr(DepthwiseConvolutionOpInterface op);
 ArrayAttr getIndexingMaps(DepthwiseConvolutionOpInterface op);
+ArrayAttr getIteratorTypes(DepthwiseConvolutionOpInterface op);
+void regionBuilder(ImplicitLocOpBuilder &b, Block &block,  ArrayRef<NamedAttribute> attrs);
+
 /// Returns true if the block contains a contraction of the following form:
 ///
 ///   %0 = <elemwise>(permutation-of(cu(block-argument-0),
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index e151d8c634dd87..6b05bc1e3f311f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -876,48 +876,63 @@ def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
   ]; 
 }
 
-def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpInterface", [LinalgConvolutionOpInterface]> {
-    let description = [{
-      A depthwise convolution is defined in general terms:
-      1. it is a convolution as defined by `ConvolutionOpInterface`
-      1. `in_channels = K * out_channels` for some integer `K`
-      4. The indexing maps of the input have expressions that satisfy
-      ```
-        AffineExpr ::== AffineDimExpr | ConvolvedExpr
-        ConvolvedExpr ::== MulExpr (`+` MulExpr)+
-        MulExpr ::== AffineDimExpr (`*` (AffineConstantExpr | AffineSymbolExpr))?
-      ```
-    }];
-    let cppNamespace = "::mlir::linalg";
-    let verify = [{ return detail::verifyConvolutionInterface($_op); }];
-    let methods = [
-      InterfaceMethod<[{
-        Returns strides attribute.
-      }],
-      "::mlir::DenseIntElementsAttr", "getStridesAttr", (ins),
-        /*methodBody=*/[{}],
-        /*defaultImplementation=*/[{
-          printf("whatever\n");
-          return detail::getStridesAttr($_op);
-        }]>,
-      InterfaceMethod<[{
-        Returns dilations attribute.
-      }],
-      "::mlir::DenseIntElementsAttr", "getDilationsAttr", (ins),
-        /*methodBody=*/[{}],
-        /*defaultImplementation=*/[{
-          printf("whatever\n");
-          return detail::getDilationsAttr($_op);
-        }]>,
-      InterfaceMethod<[{
-        Returns channel dim attribute.
-      }],
-      "::mlir::BoolAttr", "getChannelFirstAttr", (ins),
-        /*methodBody=*/[{}],
-        /*defaultImplementation=*/[{
-          return detail::getChannelFirstAttr($_op);
-        }]>
-    ];
-  }
+def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpInterface", [
+  LinalgConvolutionOpInterface]> {
+  let description = [{
+    A depthwise convolution is defined in general terms:
+    1. it is a convolution as defined by `ConvolutionOpInterface`
+    2. `out_channels = K * in_channels` for some positive integer `K` (the channel multiplier)
+    3. The indexing maps of the input have expressions that satisfy
+    ```
+      AffineExpr ::= AffineDimExpr | ConvolvedExpr
+      ConvolvedExpr ::= MulExpr (`+` MulExpr)+
+      MulExpr ::= AffineDimExpr (`*` (AffineConstantExpr | AffineSymbolExpr))?
+    ```
+  }];
+  let cppNamespace = "::mlir::linalg";
+  let verify = [{ return detail::verifyConvolutionInterface($_op); }];
+  let methods = [
+    InterfaceMethod<[{
+      Returns strides attribute.
+    }],
+    "::mlir::DenseIntElementsAttr", "getStridesAttr", (ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return detail::getStridesAttr($_op);
+      }]>,
+    InterfaceMethod<[{
+      Returns dilations attribute.
+    }],
+    "::mlir::DenseIntElementsAttr", "getDilationsAttr", (ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return detail::getDilationsAttr($_op);
+      }]>,
+    InterfaceMethod<[{
+      Returns channel dim attribute.
+    }],
+    "::mlir::BoolAttr", "getChannelFirstAttr", (ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return detail::getChannelFirstAttr($_op);
+      }]>,
+    InterfaceMethod<[{
+      Returns indexing maps for any spatial dimension.
+    }],
+    "::mlir::ArrayAttr", "getIndexingMaps", (ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return detail::getIndexingMaps($_op);
+      }]>,
+    InterfaceMethod<[{
+      Returns indexing maps for any spatial dimension.
+    }],
+    "::mlir::ArrayAttr", "getIteratorTypes", (ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return detail::getIteratorTypes($_op);
+      }]>
+  ];
+}
 
 #endif // LINALG_IR_LINALGINTERFACES
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 866f1f1fde0d51..679718c5d7ee55 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -430,27 +430,14 @@ def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
         CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
     ];
     let hasCustomAssemblyFormat = 1;
-    let hasVerifier = 1;
 
     let extraClassDeclaration = structuredOpsBaseDecls # [{
-      // Declare functions necessary for LinalgStructuredInterface.
-      SmallVector<utils::IteratorType> getIteratorTypesArray();
-      // Figure out how to get implementation in `LinalgDepthwiseConvolutionOpInterface` working 
-      ArrayAttr getIndexingMaps() { return detail::getIndexingMaps(*this); };
-      std::string getLibraryCallName() {
-        return "op_has_no_registered_library_name";
-      }
-
-      static void regionBuilder(ImplicitLocOpBuilder &b,
-                              Block &block, ArrayRef<NamedAttribute> attrs);
-
-      // Implement functions necessary for DestinationStyleOpInterface.
       static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
                                 mlir::ArrayRef<mlir::NamedAttribute>)>
       getRegionBuilder() {
-        return regionBuilder;
+        return detail::regionBuilder;
       }
-      static unsigned getNumRegionArgs() { return 3; }
+      // Implement functions necessary for DestinationStyleOpInterface.
       MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
     }];
 }
@@ -497,27 +484,14 @@ def DepthwiseConv2DOp : LinalgStructuredBase_Op<"depthwise_conv_2d",
         CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
     ];
     let hasCustomAssemblyFormat = 1;
-    let hasVerifier = 1;
 
     let extraClassDeclaration = structuredOpsBaseDecls # [{
-      // Declare functions necessary for LinalgStructuredInterface.
-      SmallVector<utils::IteratorType> getIteratorTypesArray();
-      // Figure out how to get implementation in `LinalgDepthwiseConvolutionOpInterface` working 
-      ArrayAttr getIndexingMaps() { return detail::getIndexingMaps(*this); };
-      std::string getLibraryCallName() {
-        return "op_has_no_registered_library_name";
-      }
-
-      static void regionBuilder(ImplicitLocOpBuilder &b,
-                              Block &block, ArrayRef<NamedAttribute> attrs);
-
-      // Implement functions necessary for DestinationStyleOpInterface.
       static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
                                 mlir::ArrayRef<mlir::NamedAttribute>)>
       getRegionBuilder() {
-        return regionBuilder;
+        return detail::regionBuilder;
       }
-      static unsigned getNumRegionArgs() { return 3; }
+      // Implement functions necessary for DestinationStyleOpInterface.
       MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
     }];
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index cdd830c9d90a68..5686692f9da026 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -8,11 +8,13 @@
 
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 
+#include "mlir/AsmParser/AsmParser.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineExprVisitor.h"
@@ -204,6 +206,171 @@ findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
 namespace {
 auto par = utils::IteratorType::parallel;
 auto red = utils::IteratorType::reduction;
+// TODO: Copied from LinalgOps.cpp; share one definition instead of a copy.
+class RegionBuilderHelper {
+public:
+  RegionBuilderHelper(MLIRContext *context, Block &block)
+      : context(context), block(block) {}
+
+  // Build the op for an OpDSL unary function; `arg` must be floating point.
+  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
+    if (!isFloatingPoint(arg))
+      llvm_unreachable("unsupported non numeric type");
+    OpBuilder builder = getBuilder();
+    switch (unaryFn) {
+    case UnaryFn::exp:
+      return builder.create<math::ExpOp>(arg.getLoc(), arg);
+    case UnaryFn::log:
+      return builder.create<math::LogOp>(arg.getLoc(), arg);
+    case UnaryFn::abs:
+      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
+    case UnaryFn::ceil:
+      return builder.create<math::CeilOp>(arg.getLoc(), arg);
+    case UnaryFn::floor:
+      return builder.create<math::FloorOp>(arg.getLoc(), arg);
+    case UnaryFn::negf:
+      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
+    }
+    llvm_unreachable("unsupported unary function");
+  }
+
+  // Build the op for an OpDSL binary function, dispatched on operand types.
+  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
+    bool allComplex = isComplex(arg0) && isComplex(arg1);
+    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
+    bool allInteger = isInteger(arg0) && isInteger(arg1);
+    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
+                   arg1.getType().getIntOrFloatBitWidth() == 1;
+    if (!allComplex && !allFloatingPoint && !allInteger)
+      llvm_unreachable("unsupported non numeric type");
+    OpBuilder builder = getBuilder();
+    switch (binaryFn) {
+    case BinaryFn::add:
+      if (allComplex)
+        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
+      if (allFloatingPoint)
+        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
+      if (allBool)
+        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
+      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
+    case BinaryFn::sub:
+      if (allComplex)
+        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
+      if (allFloatingPoint)
+        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
+      if (allBool)
+        llvm_unreachable("unsupported operation: sub with bools");
+      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
+    case BinaryFn::mul:
+      if (allComplex)
+        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
+      if (allFloatingPoint)
+        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
+      if (allBool)
+        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
+      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
+    case BinaryFn::div:
+      if (allComplex)
+        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
+      if (allFloatingPoint)
+        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
+      if (allBool)
+        llvm_unreachable("unsupported operation: div with bools");
+      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
+    case BinaryFn::div_unsigned:
+      if (!allInteger || allBool)
+        llvm_unreachable("unsupported operation: unsigned div not on uint");
+      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
+    case BinaryFn::max_signed:
+      assert(!allComplex);
+      if (allFloatingPoint)
+        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
+      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
+    case BinaryFn::min_signed:
+      assert(!allComplex);
+      if (allFloatingPoint)
+        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
+      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
+    case BinaryFn::max_unsigned:
+      assert(!allComplex);
+      if (allFloatingPoint)
+        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
+      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
+    case BinaryFn::min_unsigned:
+      assert(!allComplex);
+      if (allFloatingPoint)
+        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
+      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
+    }
+    llvm_unreachable("unsupported binary function");
+  }
+
+  // Build a signed/unsigned scalar cast of `operand` to `toType` (OpDSL TypeFn).
+  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
+    switch (typeFn) {
+    case TypeFn::cast_signed:
+      return cast(toType, operand, false);
+    case TypeFn::cast_unsigned:
+      return cast(toType, operand, true);
+    }
+    llvm_unreachable("unsupported type conversion function");
+  }
+
+  void yieldOutputs(ValueRange values) {
+    OpBuilder builder = getBuilder();
+    Location loc = builder.getUnknownLoc();
+    builder.create<YieldOp>(loc, values);
+  }
+
+  Value constant(const std::string &value) {
+    OpBuilder builder = getBuilder();
+    Location loc = builder.getUnknownLoc();
+    Attribute valueAttr = parseAttribute(value, builder.getContext());
+    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
+  }
+
+  Value index(int64_t dim) {
+    OpBuilder builder = getBuilder();
+    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
+  }
+
+  Type getIntegerType(unsigned width) {
+    return IntegerType::get(context, width);
+  }
+
+  Type getFloat32Type() { return Float32Type::get(context); }
+  Type getFloat64Type() { return Float64Type::get(context); }
+
+private:
+  // Generates operations to cast the given operand to a specified type.
+  // If the cast cannot be performed, a warning will be issued and the
+  // operand returned as-is (which will presumably yield a verification
+  // issue downstream).
+  Value cast(Type toType, Value operand, bool isUnsignedCast) {
+    OpBuilder builder = getBuilder();
+    auto loc = operand.getLoc();
+    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
+  }
+
+  bool isComplex(Value value) {
+    return llvm::isa<ComplexType>(value.getType());
+  }
+  bool isFloatingPoint(Value value) {
+    return llvm::isa<FloatType>(value.getType());
+  }
+  bool isInteger(Value value) {
+    return llvm::isa<IntegerType>(value.getType());
+  }
+
+  OpBuilder getBuilder() {
+    OpBuilder builder(context);
+    builder.setInsertionPointToEnd(&block);
+    return builder;
+  }
+
+  MLIRContext *context;
+  Block &block;
+};
 } // namespace
 
 /// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
@@ -664,6 +831,16 @@ mlir::linalg::detail::getChannelFirstAttr(DepthwiseConvolutionOpInterface op) {
   return op.getChannelFirstAttr();
 }
 
+ArrayAttr mlir::linalg::detail::getIteratorTypes(DepthwiseConvolutionOpInterface op) {
+  // (2 + numSpatialDims) parallel iterators, then numSpatialDims reductions.
+  int64_t numSpatialDims = cast<ShapedType>(op.image().getType()).getRank() - 2;
+  MLIRContext *ctx = op.getContext();
+  SmallVector<Attribute> iteratorTypes(2 + numSpatialDims,
+                                       IteratorTypeAttr::get(ctx, par));
+  iteratorTypes.append(numSpatialDims, IteratorTypeAttr::get(ctx, red));
+  return Builder(ctx).getArrayAttr(iteratorTypes);
+}
+
 ArrayAttr
 mlir::linalg::detail::getIndexingMaps(DepthwiseConvolutionOpInterface op) {
   ArrayAttr cached = op->getAttrOfType<ArrayAttr>(
@@ -709,6 +886,26 @@ mlir::linalg::detail::getIndexingMaps(DepthwiseConvolutionOpInterface op) {
   return cached;
 }
 
+// Shared body builder for depthwise convolutions. Block args are (input,
+// filter, init); yields init + cast_signed(input) * cast_signed(filter).
+void mlir::linalg::detail::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                                         ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "depthwise convolution regionBuilder expects 3 args");
+  RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
+
+  Value value1 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+                         block.getArgument(0));
+  Value value2 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+                         block.getArgument(1));
+  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+  Value value4 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+  helper.yieldOutputs(value4);
+}
+
 mlir::linalg::detail::MatchConvolutionResult
 mlir::linalg::detail::isConvolutionInterfaceImpl(
     Operation *op, ConvolutionDimensions *dimensions) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2528a6d1971014..b8663c7643f592 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1676,25 +1676,6 @@ LogicalResult ReduceOp::verify() {
 // DepthwiseConv1DOp
 //===----------------------------------------------------------------------===//
 
-void DepthwiseConv1DOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
-                                      ArrayRef<NamedAttribute> attrs) {
-  assert(block.getNumArguments() == 3 &&
-         "DepthwiseConv1DOp regionBuilder expects 3 (>=0) args");
-  RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
-  SmallVector<Value> yields;
-
-  Value value1 =
-      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
-                         block.getArgument(0));
-  Value value2 =
-      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
-                         block.getArgument(1));
-  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
-  Value value4 =
-      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
-  yields.push_back(value4);
-  helper.yieldOutputs(yields);
-}
 
 void DepthwiseConv1DOp::build(
     OpBuilder &builder, OperationState &result, ValueRange inputs,
@@ -1718,13 +1699,6 @@ void DepthwiseConv1DOp::build(
                        inputs, inits, bodyBuild);
 }
 
-SmallVector<utils::IteratorType> DepthwiseConv1DOp::getIteratorTypesArray() {
-  return SmallVector<utils::IteratorType>{
-      utils::IteratorType::parallel, utils::IteratorType::parallel,
-      utils::IteratorType::parallel, utils::IteratorType::reduction};
-  ;
-}
-
 void DepthwiseConv1DOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
@@ -1734,7 +1708,7 @@ void DepthwiseConv1DOp::getEffects(
 
 ParseResult DepthwiseConv1DOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
-  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+  return parseNamedStructuredOp(parser, result, /*numRegionArgs=*/3,
                                 getRegionBuilder());
 }
 
@@ -1742,32 +1716,11 @@ void DepthwiseConv1DOp::print(OpAsmPrinter &p) {
   printNamedStructuredOp(p, getOperation(), getDpsInputs(), getDpsInits());
 }
 
-LogicalResult DepthwiseConv1DOp::verify() { return success(); }
-
 //===----------------------------------------------------------------------===//
 // DepthwiseConv2DOp
 //===----------------------------------------------------------------------===//
 
 // TODO: refactor into base implementation for all spatial dims
-void DepthwiseConv2DOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
-                                      ArrayRef<NamedAttribute> attrs) {
-  assert(block.getNumArguments() == 3 &&
-         "DepthwiseConv2DOp regionBuilder expects 3 (>=0) args");
-  RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
-  SmallVector<Value> yields;
-
-  Value value1 =
-      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
-                         block.getArgument(0));
-  Value value2 =
-      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
-                         block.getArgument(1));
-  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
-  Value value4 =
-      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
-  yields.push_back(value4);
-  helper.yieldOutputs(yields);
-}
 
 void DepthwiseConv2DOp::build(
     OpBuilder &builder, OperationState &result, ValueRange inputs,
@@ -1791,14 +1744,6 @@ void DepthwiseConv2DOp::build(
                        inputs, inits, bodyBuild);
 }
 
-SmallVector<utils::IteratorType> DepthwiseConv2DOp::getIteratorTypesArray() {
-  return SmallVector<utils::IteratorType>{
-      utils::IteratorType::parallel,  utils::IteratorType::parallel,
-      utils::IteratorType::parallel,  utils::IteratorType::parallel,
-      utils::IteratorType::reduction, utils::IteratorType::reduction};
-  ;
-}
-
 void DepthwiseConv2DOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
@@ -1808,7 +1753,7 @@ void DepthwiseConv2DOp::getEffects(
 
 ParseResult DepthwiseConv2DOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
-  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+  return parseNamedStructuredOp(parser, result, /*numRegionArgs=*/3,
                                 getRegionBuilder());
 }
 
@@ -1816,8 +1761,6 @@ void DepthwiseConv2DOp::print(OpAsmPrinter &p) {
   printNamedStructuredOp(p, getOperation(), getDpsInputs(), getDpsInits());
 }
 
-LogicalResult DepthwiseConv2DOp::verify() { return success(); }
-
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index d23036e07f2223..40d5078fc1d5fc 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -882,17 +882,17 @@ func.func @gen_depthwise_channel_last_memref(%arg0: memref<?x?x?xf32>, %arg1: me
 //    COMMON-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
 //    COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32>
 //    COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
-//         COMMON-DAG: %[[c0:.*]] = arith.constant 0 : index
-//         COMMON-DAG: %[[c1:.*]] = arith.constant 1 : index
-//         COMMON-DAG: %[[c2:.*]] = arith.constant 2 : index
-//         COMMON-DAG: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?xf32>
-//         COMMON-DAG: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<?x?x?xf32>
-//         COMMON-DAG: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c0]] : memref<?x?xf32>
-//         COMMON-DAG: %[[dim3:.*]] = memref.dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
+//     COMMON-DAG: %[[c0:.*]] = arith.constant 0 : index
+//     COMMON-DAG: %[[c1:.*]] = arith.constant 1 : index
+//     COMMON-DAG: %[[c2:.*]] = arith.constant 2 : index
+//     COMMON-DAG: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?xf32>
+//     COMMON-DAG: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<?x?x?xf32>
+//     COMMON-DAG: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c0]] : memref<?x?xf32>
+//     COMMON-DAG: %[[dim3:.*]] = memref.dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
 //          CHECK: scf.for %[[n:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
 //          CHECK:   scf.for %[[ow:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
 //          CHECK:     scf.for %[[c:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
-//  CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[ow:.*]], %[[c:.*]]) = (%[[c0]], %[[c0]], %[[c0]])  to (%[[dim0]], %[[dim3]], %[[dim1]])
+//  CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[ow:.*]], %[[c:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim3]], %[[dim1]])
 //         COMMON:       scf.for %[[kw:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
 //         COMMON:         %[[aff:.*]] = affine.apply #[[$stride3Dilation2]](%[[ow]], %[[kw]])
 //         COMMON:         %[[vb:.*]] = memref.load %[[arg0]][%[[n]], %[[aff]], %[[c]]] : memref<?x?x?xf32>
@@ -911,16 +911,16 @@ func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<?x?x?x?xf32>, %ar
 //    COMMON-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
 //    COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
 //    COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
-//         COMMON-DAG: %[[c0:.*]] = arith.constant 0 : index
-//         COMMON-DAG: %[[c1:.*]] = arith.constant 1 : index
-//         COMMON-DAG: %[[c2:.*]] = arith.constant 2 : index
-//         COMMON-DAG: %[[c3:.*]] = arith.constant 3 : index
-//         COMMON-DAG: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?x?xf32>
-//         COMMON-DAG: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c1]] : memref<?x?x?x?xf32>
-//         COMMON-DAG: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
-//         COMMON-DAG: %[[dim3:.*]] = memref.dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
-//         COMMON-DAG: %[[dim4:.*]] = memref.dim %[[arg2]], %[[c2]] : memref<?x?x?x?xf32>
-//         COMMON-DAG: %[[dim5:.*]] = memref.dim %[[arg2]], %[[c3]] : memref<?x?x?x?xf32>
+//     COMMON-DAG: %[[c0:.*]] = arith.constant 0 : index
+//     COMMON-DAG: %[[c1:.*]] = arith.constant 1 : index
+//     COMMON-DAG: %[[c2:.*]] = arith.constant 2 : index
+//     COMMON-DAG: %[[c3:.*]] = arith.constant 3 : index
+//     COMMON-DAG: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?x?xf32>
+//     COMMON-DAG: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c1]] : memref<?x?x?x?xf32>
+//     COMMON-DAG: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
+//     COMMON-DAG: %[[dim3:.*]] = memref.dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
+//     COMMON-DAG: %[[dim4:.*]] = memref.dim %[[arg2]], %[[c2]] : memref<?x?x?x?xf32>
+//     COMMON-DAG: %[[dim5:.*]] = memref.dim %[[arg2]], %[[c3]] : memref<?x?x?x?xf32>
 //          CHECK: scf.for %[[n:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
 //          CHECK:     scf.for %[[oh:.*]] = %[[c0]] to %[[dim4]] step %[[c1]] {
 //          CHECK:      scf.for %[[ow:.*]] = %[[c0]] to %[[dim5]] step %[[c1]] {

>From 44a5e434a16c9cd7c4258cf57675d07c400d91bf Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 9 Dec 2023 02:41:55 -0600
Subject: [PATCH 09/22] More refactoring

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |  13 +-
 .../Dialect/Linalg/IR/LinalgInterfaces.td     |  30 ++-
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  36 +--
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 213 ++----------------
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 106 ++++-----
 5 files changed, 109 insertions(+), 289 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index a278d6509c19b6..1c0e66edeabbb1 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -116,13 +116,22 @@ bool isaCopyOpInterface(LinalgOp linalgOp);
 
 namespace detail {
 
-// Common implementations for DepthwiseConvolutionOpInterface 
+// Common implementations for DepthwiseConvolutionOpInterface
+namespace depthwise_convolution_impl {
 DenseIntElementsAttr getStridesAttr(DepthwiseConvolutionOpInterface op);
 DenseIntElementsAttr getDilationsAttr(DepthwiseConvolutionOpInterface op);
 BoolAttr getChannelFirstAttr(DepthwiseConvolutionOpInterface op);
 ArrayAttr getIndexingMaps(DepthwiseConvolutionOpInterface op);
 ArrayAttr getIteratorTypes(DepthwiseConvolutionOpInterface op);
-void regionBuilder(ImplicitLocOpBuilder &b, Block &block,  ArrayRef<NamedAttribute> attrs);
+void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                   ArrayRef<NamedAttribute> attrs);
+void getEffects(
+    DepthwiseConvolutionOpInterface op,
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects);
+ParseResult parse(OpAsmParser &parser, OperationState &result);
+void print(DepthwiseConvolutionOpInterface op, OpAsmPrinter &p);
+} // namespace depthwise_convolution_impl
 
 /// Returns true if the block contains a contraction of the following form:
 ///
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 6b05bc1e3f311f..9373e8811f443f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -175,6 +175,26 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
         return $_op.getOperation()->getOperand(1);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Return the init operand.",
+      /*retTy=*/"Value",
+      /*methodName=*/"init",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.getOperation()->getOperand(2);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the first result, or a null Value if the op has none.",
+      /*retTy=*/"Value",
+      /*methodName=*/"result",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op->getNumResults() ? $_op->getResult(0) : Value();
+      }]
+    >,
   ];
 }
 
@@ -898,7 +918,7 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
     "::mlir::DenseIntElementsAttr", "getStridesAttr", (ins),
       /*methodBody=*/[{}],
       /*defaultImplementation=*/[{
-        return detail::getStridesAttr($_op);
+        return detail::depthwise_convolution_impl::getStridesAttr($_op);
       }]>,
     InterfaceMethod<[{
       Returns dilations attribute.
@@ -906,7 +926,7 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
     "::mlir::DenseIntElementsAttr", "getDilationsAttr", (ins),
       /*methodBody=*/[{}],
       /*defaultImplementation=*/[{
-        return detail::getDilationsAttr($_op);
+        return detail::depthwise_convolution_impl::getDilationsAttr($_op);
       }]>,
     InterfaceMethod<[{
       Returns channel dim attribute.
@@ -914,7 +934,7 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
     "::mlir::BoolAttr", "getChannelFirstAttr", (ins),
       /*methodBody=*/[{}],
       /*defaultImplementation=*/[{
-        return detail::getChannelFirstAttr($_op);
+        return detail::depthwise_convolution_impl::getChannelFirstAttr($_op);
       }]>,
     InterfaceMethod<[{
       Returns indexing maps for any spatial dimension.
@@ -922,7 +942,7 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
     "::mlir::ArrayAttr", "getIndexingMaps", (ins),
       /*methodBody=*/[{}],
       /*defaultImplementation=*/[{
-        return detail::getIndexingMaps($_op);
+        return detail::depthwise_convolution_impl::getIndexingMaps($_op);
       }]>,
     InterfaceMethod<[{
       Returns indexing maps for any spatial dimension.
@@ -930,7 +950,7 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
     "::mlir::ArrayAttr", "getIteratorTypes", (ins),
       /*methodBody=*/[{}],
       /*defaultImplementation=*/[{
-        return detail::getIteratorTypes($_op);
+        return detail::depthwise_convolution_impl::getIteratorTypes($_op);
       }]>
   ];
 }
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 679718c5d7ee55..c4322dd2c7d536 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -423,19 +423,19 @@ def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
     let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
     let regions = (region AnyRegion:$region);
 
-    let builders = [
-      OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits,
-        "bool":$channel_first,
-        "function_ref<void(OpBuilder &, Location, ValueRange)>",
-        CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
-    ];
-    let hasCustomAssemblyFormat = 1;
-
+    // TODO: Figure out how to move this to the interface
     let extraClassDeclaration = structuredOpsBaseDecls # [{
+      void print(::mlir::OpAsmPrinter &printer) {
+        return detail::depthwise_convolution_impl::print(*this, printer);
+      }
+      static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
+                                        ::mlir::OperationState &result) {
+        return detail::depthwise_convolution_impl::parse(parser, result);
+      }
       static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
                                 mlir::ArrayRef<mlir::NamedAttribute>)>
       getRegionBuilder() {
-        return detail::regionBuilder;
+        return detail::depthwise_convolution_impl::regionBuilder;
       }
       // Implement functions necessary for DestinationStyleOpInterface.
       MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
@@ -477,19 +477,19 @@ def DepthwiseConv2DOp : LinalgStructuredBase_Op<"depthwise_conv_2d",
     let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
     let regions = (region AnyRegion:$region);
 
-    let builders = [
-      OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits,
-        "bool":$channel_first,
-        "function_ref<void(OpBuilder &, Location, ValueRange)>",
-        CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
-    ];
-    let hasCustomAssemblyFormat = 1;
-
+    // TODO: Figure out how to move this to the interface
     let extraClassDeclaration = structuredOpsBaseDecls # [{
+      void print(::mlir::OpAsmPrinter &printer) {
+        return detail::depthwise_convolution_impl::print(*this, printer);
+      }
+      static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
+                                        ::mlir::OperationState &result) {
+        return detail::depthwise_convolution_impl::parse(parser, result);
+      }
       static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
                                 mlir::ArrayRef<mlir::NamedAttribute>)>
       getRegionBuilder() {
-        return detail::regionBuilder;
+        return detail::depthwise_convolution_impl::regionBuilder;
       }
       // Implement functions necessary for DestinationStyleOpInterface.
       MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 5686692f9da026..b6c073a8f11f26 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -206,171 +206,6 @@ findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
 namespace {
 auto par = utils::IteratorType::parallel;
 auto red = utils::IteratorType::reduction;
-// TODO: Figure out a way to not copy this from LinalgOps.cpp
-class RegionBuilderHelper {
-public:
-  RegionBuilderHelper(MLIRContext *context, Block &block)
-      : context(context), block(block) {}
-
-  // Build the unary functions defined by OpDSL.
-  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
-    if (!isFloatingPoint(arg))
-      llvm_unreachable("unsupported non numeric type");
-    OpBuilder builder = getBuilder();
-    switch (unaryFn) {
-    case UnaryFn::exp:
-      return builder.create<math::ExpOp>(arg.getLoc(), arg);
-    case UnaryFn::log:
-      return builder.create<math::LogOp>(arg.getLoc(), arg);
-    case UnaryFn::abs:
-      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
-    case UnaryFn::ceil:
-      return builder.create<math::CeilOp>(arg.getLoc(), arg);
-    case UnaryFn::floor:
-      return builder.create<math::FloorOp>(arg.getLoc(), arg);
-    case UnaryFn::negf:
-      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
-    }
-    llvm_unreachable("unsupported unary function");
-  }
-
-  // Build the binary functions defined by OpDSL.
-  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
-    bool allComplex = isComplex(arg0) && isComplex(arg1);
-    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
-    bool allInteger = isInteger(arg0) && isInteger(arg1);
-    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
-                   arg1.getType().getIntOrFloatBitWidth() == 1;
-    if (!allComplex && !allFloatingPoint && !allInteger)
-      llvm_unreachable("unsupported non numeric type");
-    OpBuilder builder = getBuilder();
-    switch (binaryFn) {
-    case BinaryFn::add:
-      if (allComplex)
-        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
-      if (allFloatingPoint)
-        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
-      if (allBool)
-        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
-      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
-    case BinaryFn::sub:
-      if (allComplex)
-        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
-      if (allFloatingPoint)
-        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
-      if (allBool)
-        llvm_unreachable("unsupported operation: sub with bools");
-      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
-    case BinaryFn::mul:
-      if (allComplex)
-        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
-      if (allFloatingPoint)
-        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
-      if (allBool)
-        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
-      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
-    case BinaryFn::div:
-      if (allComplex)
-        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
-      if (allFloatingPoint)
-        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
-      if (allBool)
-        llvm_unreachable("unsupported operation: div with bools");
-      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
-    case BinaryFn::div_unsigned:
-      if (!allInteger || allBool)
-        llvm_unreachable("unsupported operation: unsigned div not on uint");
-      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
-    case BinaryFn::max_signed:
-      assert(!allComplex);
-      if (allFloatingPoint)
-        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
-      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
-    case BinaryFn::min_signed:
-      assert(!allComplex);
-      if (allFloatingPoint)
-        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
-      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
-    case BinaryFn::max_unsigned:
-      assert(!allComplex);
-      if (allFloatingPoint)
-        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
-      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
-    case BinaryFn::min_unsigned:
-      assert(!allComplex);
-      if (allFloatingPoint)
-        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
-      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
-    }
-    llvm_unreachable("unsupported binary function");
-  }
-
-  // Build the type functions defined by OpDSL.
-  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
-    switch (typeFn) {
-    case TypeFn::cast_signed:
-      return cast(toType, operand, false);
-    case TypeFn::cast_unsigned:
-      return cast(toType, operand, true);
-    }
-    llvm_unreachable("unsupported type conversion function");
-  }
-
-  void yieldOutputs(ValueRange values) {
-    OpBuilder builder = getBuilder();
-    Location loc = builder.getUnknownLoc();
-    builder.create<YieldOp>(loc, values);
-  }
-
-  Value constant(const std::string &value) {
-    OpBuilder builder = getBuilder();
-    Location loc = builder.getUnknownLoc();
-    Attribute valueAttr = parseAttribute(value, builder.getContext());
-    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
-  }
-
-  Value index(int64_t dim) {
-    OpBuilder builder = getBuilder();
-    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
-  }
-
-  Type getIntegerType(unsigned width) {
-    return IntegerType::get(context, width);
-  }
-
-  Type getFloat32Type() { return Float32Type::get(context); }
-  Type getFloat64Type() { return Float64Type::get(context); }
-
-private:
-  // Generates operations to cast the given operand to a specified type.
-  // If the cast cannot be performed, a warning will be issued and the
-  // operand returned as-is (which will presumably yield a verification
-  // issue downstream).
-  Value cast(Type toType, Value operand, bool isUnsignedCast) {
-    OpBuilder builder = getBuilder();
-    auto loc = operand.getLoc();
-    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
-  }
-
-  bool isComplex(Value value) {
-    return llvm::isa<ComplexType>(value.getType());
-  }
-  bool isFloatingPoint(Value value) {
-    return llvm::isa<FloatType>(value.getType());
-  }
-  bool isInteger(Value value) {
-    return llvm::isa<IntegerType>(value.getType());
-  }
-
-  OpBuilder getBuilder() {
-    OpBuilder builder(context);
-    builder.setInsertionPointToEnd(&block);
-    return builder;
-  }
-
-  MLIRContext *context;
-  Block █
-};
 } // namespace
 
 /// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
@@ -806,7 +641,8 @@ enum class MatchConvolutionResult {
 } // namespace mlir::linalg::detail
 
 DenseIntElementsAttr
-mlir::linalg::detail::getStridesAttr(DepthwiseConvolutionOpInterface op) {
+mlir::linalg::detail::depthwise_convolution_impl::getStridesAttr(
+    DepthwiseConvolutionOpInterface op) {
   auto maybeStridesAttr = op.getStridesAttr();
   maybeStridesAttr.dump();
   if (!maybeStridesAttr) {
@@ -822,27 +658,32 @@ mlir::linalg::detail::getStridesAttr(DepthwiseConvolutionOpInterface op) {
 }
 
 DenseIntElementsAttr
-mlir::linalg::detail::getDilationsAttr(DepthwiseConvolutionOpInterface op) {
+mlir::linalg::detail::depthwise_convolution_impl::getDilationsAttr(
+    DepthwiseConvolutionOpInterface op) {
   return op.getDilationsAttr();
 }
 
-BoolAttr
-mlir::linalg::detail::getChannelFirstAttr(DepthwiseConvolutionOpInterface op) {
+BoolAttr mlir::linalg::detail::depthwise_convolution_impl::getChannelFirstAttr(
+    DepthwiseConvolutionOpInterface op) {
   return op.getChannelFirstAttr();
 }
 
-ArrayAttr mlir::linalg::detail::getIteratorTypes(DepthwiseConvolutionOpInterface op) {
+ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIteratorTypes(
+    DepthwiseConvolutionOpInterface op) {
   int64_t numSpatialDims =
-        op.image().getType().cast<ShapedType>().getRank() - 2;
-  SmallVector<Attribute> iteratorTypes(2 + numSpatialDims, IteratorTypeAttr::get(op.getContext(), par));
-  SmallVector<Attribute> reductions(numSpatialDims, IteratorTypeAttr::get(op.getContext(), red));
-  iteratorTypes.insert(iteratorTypes.end(), reductions.begin(), reductions.end());
+      op.image().getType().cast<ShapedType>().getRank() - 2;
+  SmallVector<Attribute> iteratorTypes(
+      2 + numSpatialDims, IteratorTypeAttr::get(op.getContext(), par));
+  SmallVector<Attribute> reductions(
+      numSpatialDims, IteratorTypeAttr::get(op.getContext(), red));
+  iteratorTypes.insert(iteratorTypes.end(), reductions.begin(),
+                       reductions.end());
 
   return Builder(op.getContext()).getArrayAttr(iteratorTypes);
 }
 
-ArrayAttr
-mlir::linalg::detail::getIndexingMaps(DepthwiseConvolutionOpInterface op) {
+ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
+    DepthwiseConvolutionOpInterface op) {
   ArrayAttr cached = op->getAttrOfType<ArrayAttr>(
       LinalgDialect::kMemoizedIndexingMapsAttrName);
   if (cached)
@@ -886,26 +727,6 @@ mlir::linalg::detail::getIndexingMaps(DepthwiseConvolutionOpInterface op) {
   return cached;
 }
 
-void mlir::linalg::detail::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
-                                         ArrayRef<NamedAttribute> attrs) {
-  assert(block.getNumArguments() == 3 &&
-         "DepthwiseConv1DOp regionBuilder expects 3 (>=0) args");
-  RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
-  SmallVector<Value> yields;
-
-  Value value1 =
-      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
-                         block.getArgument(0));
-  Value value2 =
-      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
-                         block.getArgument(1));
-  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
-  Value value4 =
-      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
-  yields.push_back(value4);
-  helper.yieldOutputs(yields);
-}
-
 mlir::linalg::detail::MatchConvolutionResult
 mlir::linalg::detail::isConvolutionInterfaceImpl(
     Operation *op, ConvolutionDimensions *dimensions) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b8663c7643f592..76879536fe21c7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1673,92 +1673,62 @@ LogicalResult ReduceOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
-// DepthwiseConv1DOp
+// DepthwiseConvNDOp
 //===----------------------------------------------------------------------===//
 
-
-void DepthwiseConv1DOp::build(
-    OpBuilder &builder, OperationState &result, ValueRange inputs,
-    ValueRange inits, bool channel_first,
-    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
-    ArrayRef<NamedAttribute> attributes) {
-  build(builder, result, TypeRange{}, inputs, inits, channel_first);
-  result.addAttribute(getChannelFirstAttrName(result.name),
-                      builder.getBoolAttr(channel_first));
-  result.addAttributes(attributes);
-
-  // Add output types for `RankedTensorType` output arguments.
-  for (Value init : inits) {
-    Type initType = init.getType();
-    if (llvm::isa<RankedTensorType>(initType))
-      result.addTypes(initType);
-  }
-
-  if (bodyBuild)
-    buildGenericRegion(builder, result.location, *result.regions.front(),
-                       inputs, inits, bodyBuild);
-}
-
-void DepthwiseConv1DOp::getEffects(
+// There must be a way to avoid defining the following 3 functions
+void mlir::linalg::detail::depthwise_convolution_impl::getEffects(
+    DepthwiseConvolutionOpInterface op,
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
-                        getDpsInits());
+  getGenericEffectsImpl(effects, op->getResults(), ValueRange{op.image(), op.filter()}, op.init());
 }
 
-ParseResult DepthwiseConv1DOp::parse(OpAsmParser &parser,
-                                     OperationState &result) {
-  return parseNamedStructuredOp(parser, result, 3,
-                                getRegionBuilder());
+ParseResult mlir::linalg::detail::depthwise_convolution_impl::parse(
+    OpAsmParser &parser, OperationState &result) {
+  return parseNamedStructuredOp(
+      parser, result, 3,
+      mlir::linalg::detail::depthwise_convolution_impl::regionBuilder);
 }
 
-void DepthwiseConv1DOp::print(OpAsmPrinter &p) {
-  printNamedStructuredOp(p, getOperation(), getDpsInputs(), getDpsInits());
+void mlir::linalg::detail::depthwise_convolution_impl::print(
+    DepthwiseConvolutionOpInterface op, OpAsmPrinter &p) {
+  printNamedStructuredOp(p, op.getOperation(), ValueRange{op.image(), op.filter()}, op.init());
 }
 
-//===----------------------------------------------------------------------===//
-// DepthwiseConv2DOp
-//===----------------------------------------------------------------------===//
-
-// TODO: refactor into base implementation for all spatial dims
+// Build {mul, add} region for convolution
+void mlir::linalg::detail::depthwise_convolution_impl::regionBuilder(
+    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "DepthwiseConv1DOp regionBuilder expects 3 (>=0) args");
+  RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
+  SmallVector<Value> yields;
 
-void DepthwiseConv2DOp::build(
-    OpBuilder &builder, OperationState &result, ValueRange inputs,
-    ValueRange inits, bool channel_first,
-    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
-    ArrayRef<NamedAttribute> attributes) {
-  build(builder, result, TypeRange{}, inputs, inits, channel_first);
-  result.addAttribute(getChannelFirstAttrName(result.name),
-                      builder.getBoolAttr(channel_first));
-  result.addAttributes(attributes);
-
-  // Add output types for `RankedTensorType` output arguments.
-  for (Value init : inits) {
-    Type initType = init.getType();
-    if (llvm::isa<RankedTensorType>(initType))
-      result.addTypes(initType);
-  }
-
-  if (bodyBuild)
-    buildGenericRegion(builder, result.location, *result.regions.front(),
-                       inputs, inits, bodyBuild);
+  Value value1 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+                         block.getArgument(0));
+  Value value2 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+                         block.getArgument(1));
+  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+  Value value4 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+  yields.push_back(value4);
+  helper.yieldOutputs(yields);
 }
 
-void DepthwiseConv2DOp::getEffects(
+// TODO: Figure out how to move this to interface
+void DepthwiseConv1DOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
   getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
                         getDpsInits());
 }
-
-ParseResult DepthwiseConv2DOp::parse(OpAsmParser &parser,
-                                     OperationState &result) {
-  return parseNamedStructuredOp(parser, result, 3,
-                                getRegionBuilder());
-}
-
-void DepthwiseConv2DOp::print(OpAsmPrinter &p) {
-  printNamedStructuredOp(p, getOperation(), getDpsInputs(), getDpsInits());
+void DepthwiseConv2DOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+                        getDpsInits());
 }
 
 //===----------------------------------------------------------------------===//

>From 93552bc8de16fc4748acee9b92a1b89294f9acf6 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 9 Dec 2023 03:25:37 -0600
Subject: [PATCH 10/22] Add 3D depthwise (relatively easily after refactor)

---
 .../Dialect/Linalg/IR/LinalgInterfaces.td     |  5 +-
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 54 +++++++++++++++++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  6 +++
 mlir/test/Dialect/Linalg/named-ops.mlir       |  9 ++++
 4 files changed, 72 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 9373e8811f443f..8753833b81d20e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -900,8 +900,9 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
   LinalgConvolutionOpInterface]> {
   let description = [{
     A depthwise convolution is defined in general terms:
-    1. it is a convolution as defined by `ConvolutionOpInterface`
-    1. `in_channels = K * out_channels` for some integer `K`
+    1. it is a convolution as defined by `ConvolutionOpInterface`.
+    2. `out_channels = K * in_channels` for some integer `K` (the channel multiplier).
+    3. `input_rank == output_rank == kernel_rank + 1` (including batch dim in input and output)
     4. The indexing maps of the input have expressions that satisfy
     ```
       AffineExpr ::== AffineDimExpr | ConvolvedExpr
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index c4322dd2c7d536..ffa9cc0562aeab 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -496,6 +496,60 @@ def DepthwiseConv2DOp : LinalgStructuredBase_Op<"depthwise_conv_2d",
     }];
 }
 
+def DepthwiseConv3DOp : LinalgStructuredBase_Op<"depthwise_conv_3d",
+  [AttrSizedOperandSegments, LinalgDepthwiseConvolutionOpInterface]> {
+    
+  let summary = [{
+    Performs 3-D depthwise convolution with switchable channel position; either first or last.
+  }];
+  let description = [{
+    Domain: N, OD, OH, OW, C, KD, KH, KW
+
+    Layout of operands is determined by the `channel_first` `BoolAttr`:
+
+    `channel_first == true`:
+      Input: `NCDHW`
+      Kernel: `CDHW`
+      Output: `NCDHW`
+
+    `channel_first == false`:
+      Input: `NDHWC`
+      Kernel: `DHWC`
+      Output: `NDHWC`
+
+  }];
+
+    let arguments = (ins
+      Variadic<TensorOrMemref>:$inputs,
+      Variadic<TensorOrMemref>:$inits,
+      DefaultValuedAttr<BoolAttr, "true">:$channel_first,
+      DefaultValuedAttr<RankedI64ElementsAttr<[3]>,
+        "{ static_cast<int64_t>(1) }">:$strides,
+      DefaultValuedAttr<RankedI64ElementsAttr<[3]>,
+        "{ static_cast<int64_t>(1) }">:$dilations
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    // TODO: Figure out how to move this to the interface
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      void print(::mlir::OpAsmPrinter &printer) {
+        return detail::depthwise_convolution_impl::print(*this, printer);
+      }
+      static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
+                                        ::mlir::OperationState &result) {
+        return detail::depthwise_convolution_impl::parse(parser, result);
+      }
+      static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+                                mlir::ArrayRef<mlir::NamedAttribute>)>
+      getRegionBuilder() {
+        return detail::depthwise_convolution_impl::regionBuilder;
+      }
+      // Implement functions necessary for DestinationStyleOpInterface.
+      MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
+    }];
+}
+
 //===----------------------------------------------------------------------===//
 // Transpose op.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 76879536fe21c7..23c7c4730a16aa 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1730,6 +1730,12 @@ void DepthwiseConv2DOp::getEffects(
   getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
                         getDpsInits());
 }
+void DepthwiseConv3DOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+                        getDpsInits());
+}
 
 //===----------------------------------------------------------------------===//
 // TransposeOp
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 54075dbe36ab47..26df1ea316e532 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -18,6 +18,15 @@ func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<64x16x10x10xf32>,
 
 // -----
 
+// CHECK-LABEL: func @gen_depthwise_3D_channel_first_memref
+func.func @gen_depthwise_3D_channel_first_memref(%arg0: memref<64x16x10x10x10xf32>, %arg1: memref<16x3x3x3xf32>, %arg2: memref<64x16x8x8x8xf32>) {
+  // CHECK: depthwise_conv_3d {{.*}}channel_first = true
+    linalg.depthwise_conv_3d {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10x10xf32>, memref<16x3x3x3xf32>) outs(%arg2: memref<64x16x8x8x8xf32>)
+    return
+}
+
+// -----
+
 // CHECK-LABEL: func @gen_depthwise_channel_last_memref
 func.func @gen_depthwise_channel_last_memref(%arg0: memref<64x26x16xf32>, %arg1: memref<3x16xf32>, %arg2: memref<64x8x16xf32>) {
   // CHECK: depthwise_conv_1d {{.*}}channel_first = false 

>From 123ffc7a757542fe82bda02b8e376c6237ac027d Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 9 Dec 2023 03:48:33 -0600
Subject: [PATCH 11/22] Add verifier for depthwise interface

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h      |  4 ++++
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.td     |  2 +-
 mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp    | 14 ++++++++++++++
 3 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 1c0e66edeabbb1..ed6cce6376932b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -189,6 +189,10 @@ LogicalResult verifyContractionInterface(Operation *op);
 /// Verify that `op` conforms to the ConvolutionOpInterface.
 LogicalResult verifyConvolutionInterface(Operation *op);
 
+/// Verify that `op` conforms to the DepthwiseConvolutionOpInterface.
+LogicalResult verifyDepthwiseConvolutionInterface(Operation *op);
+
+LogicalResult verifyConvolutionInterface(Operation *op);
 /// Verify that `op` conforms to the FillOpInterface.
 LogicalResult verifyFillInterface(Operation *op);
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 8753833b81d20e..9749f301dfb153 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -911,7 +911,7 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
     ```
   }];
   let cppNamespace = "::mlir::linalg";
-  let verify = [{ return detail::verifyConvolutionInterface($_op); }];
+  let verify = [{ return detail::verifyDepthwiseConvolutionInterface($_op); }];
   let methods = [
     InterfaceMethod<[{
       Returns strides attribute.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index b6c073a8f11f26..6fa9fd92304d30 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -727,6 +727,20 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
   return cached;
 }
 
+LogicalResult mlir::linalg::detail::verifyDepthwiseConvolutionInterface(Operation *op) {
+  if (failed(verifyConvolutionInterface(op)))
+    return failure();
+  auto conv = dyn_cast<DepthwiseConvolutionOpInterface>(op);
+  if (!conv)
+    return op->emitError("expected op to implement DepthwiseConvolutionOpInterface");
+  int64_t imageRank = cast<ShapedType>(conv.image().getType()).getRank();
+  int64_t kernelRank = cast<ShapedType>(conv.filter().getType()).getRank();
+  int64_t initRank = cast<ShapedType>(conv.init().getType()).getRank();
+  if (imageRank != initRank || imageRank != kernelRank + 1)
+    return op->emitError("Rank relationship must be `in_rank == out_rank == kernel_rank + 1`");
+  return success();
+}
+
 mlir::linalg::detail::MatchConvolutionResult
 mlir::linalg::detail::isConvolutionInterfaceImpl(
     Operation *op, ConvolutionDimensions *dimensions) {

>From b9be7566cf8da5bbd9812d9e3845f3d9c7a80a64 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 9 Dec 2023 05:11:36 -0600
Subject: [PATCH 12/22] Refactor separate spatial dims into one

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |   1 -
 .../Dialect/Linalg/IR/LinalgInterfaces.td     |  41 +++--
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 141 +++---------------
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    |  40 ++---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  14 +-
 mlir/test/Dialect/Linalg/loops.mlir           |   6 +-
 mlir/test/Dialect/Linalg/named-ops.mlir       |  30 ++--
 7 files changed, 77 insertions(+), 196 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index ed6cce6376932b..23240052607973 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -120,7 +120,6 @@ namespace detail {
 namespace depthwise_convolution_impl {
 DenseIntElementsAttr getStridesAttr(DepthwiseConvolutionOpInterface op);
 DenseIntElementsAttr getDilationsAttr(DepthwiseConvolutionOpInterface op);
-BoolAttr getChannelFirstAttr(DepthwiseConvolutionOpInterface op);
 ArrayAttr getIndexingMaps(DepthwiseConvolutionOpInterface op);
 ArrayAttr getIteratorTypes(DepthwiseConvolutionOpInterface op);
 void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 9749f301dfb153..d524fefbd43d24 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -913,30 +913,6 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
   let cppNamespace = "::mlir::linalg";
   let verify = [{ return detail::verifyDepthwiseConvolutionInterface($_op); }];
   let methods = [
-    InterfaceMethod<[{
-      Returns strides attribute.
-    }],
-    "::mlir::DenseIntElementsAttr", "getStridesAttr", (ins),
-      /*methodBody=*/[{}],
-      /*defaultImplementation=*/[{
-        return detail::depthwise_convolution_impl::getStridesAttr($_op);
-      }]>,
-    InterfaceMethod<[{
-      Returns dilations attribute.
-    }],
-    "::mlir::DenseIntElementsAttr", "getDilationsAttr", (ins),
-      /*methodBody=*/[{}],
-      /*defaultImplementation=*/[{
-        return detail::depthwise_convolution_impl::getDilationsAttr($_op);
-      }]>,
-    InterfaceMethod<[{
-      Returns channel dim attribute.
-    }],
-    "::mlir::BoolAttr", "getChannelFirstAttr", (ins),
-      /*methodBody=*/[{}],
-      /*defaultImplementation=*/[{
-        return detail::depthwise_convolution_impl::getChannelFirstAttr($_op);
-      }]>,
     InterfaceMethod<[{
       Returns indexing maps for any spatial dimension.
     }],
@@ -954,6 +930,23 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
         return detail::depthwise_convolution_impl::getIteratorTypes($_op);
       }]>
   ];
+
+  let extraClassDeclaration = [{
+    // Returns channel first attribute.
+    bool getChannelFirst() {
+        return (*this)->getAttrOfType<BoolAttr>("channel_first").getValue();
+    }
+  }];
+  let extraSharedClassDeclaration = [{
+    // Returns strides attribute.
+    ::mlir::DenseIntElementsAttr getStridesAttr() {
+        return detail::depthwise_convolution_impl::getStridesAttr($_op);
+    }
+    // Returns dilations attribute.
+    ::mlir::DenseIntElementsAttr getDilationsAttr() {
+        return detail::depthwise_convolution_impl::getDilationsAttr($_op);
+    }
+  }];
 }
 
 #endif // LINALG_IR_LINALGINTERFACES
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index ffa9cc0562aeab..c8e7746b17fac9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -388,80 +388,30 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
 // DepthwiseConvNDOp ops.
 //===----------------------------------------------------------------------===//
 
-def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
+def DepthwiseConvNDOp : LinalgStructuredBase_Op<"depthwise_conv_nd",
   [AttrSizedOperandSegments, LinalgDepthwiseConvolutionOpInterface]> {
     
   let summary = [{
-    Performs 1-D depthwise convolution with switchable channel position; either first or last.
+    Performs N-D depthwise convolution with switchable channel position; either first or last.
   }];
   let description = [{
-    Domain: N, OW, C, KW
-
-    Layout of operands is determined by the `channel_first` `BoolAttr`:
-
-    `channel_first == true`:
-      Input: `NCW`
-      Kernel: `CW`
-      Output: `NCW`
-
-    `channel_first == false`:
-      Input: `NWC`
-      Kernel: `WC`
-      Output: `NWC`
-
-  }];
-
-    let arguments = (ins
-      Variadic<TensorOrMemref>:$inputs,
-      Variadic<TensorOrMemref>:$inits,
-      DefaultValuedAttr<BoolAttr, "true">:$channel_first,
-      DefaultValuedAttr<RankedI64ElementsAttr<[1]>,
-        "{ static_cast<int64_t>(1) }">:$strides,
-      DefaultValuedAttr<RankedI64ElementsAttr<[1]>,
-        "{ static_cast<int64_t>(1) }">:$dilations
-    );
-    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
-    let regions = (region AnyRegion:$region);
-
-    // TODO: Figure out how to move this to the interface
-    let extraClassDeclaration = structuredOpsBaseDecls # [{
-      void print(::mlir::OpAsmPrinter &printer) {
-        return detail::depthwise_convolution_impl::print(*this, printer);
-      }
-      static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
-                                        ::mlir::OperationState &result) {
-        return detail::depthwise_convolution_impl::parse(parser, result);
-      }
-      static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
-                                mlir::ArrayRef<mlir::NamedAttribute>)>
-      getRegionBuilder() {
-        return detail::depthwise_convolution_impl::regionBuilder;
-      }
-      // Implement functions necessary for DestinationStyleOpInterface.
-      MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
-    }];
-}
-
-def DepthwiseConv2DOp : LinalgStructuredBase_Op<"depthwise_conv_2d",
-  [AttrSizedOperandSegments, LinalgDepthwiseConvolutionOpInterface]> {
+    Allows any number of spatial dimensions but treats all of them as contiguous.  Throughout, `S`
+    represents all spatial dimensions.  Operand layouts are determined by the `channel_first`
+    `bool` attribute.  "First" and "last" do not count the batch dim, which always stays outermost.  In
+    any case, the channel and spatial dims are in the same relative order for all operands.
     
-  let summary = [{
-    Performs 2-D depthwise convolution with switchable channel position; either first or last.
-  }];
-  let description = [{
-    Domain: N, OH, OW, C, KH, KW
-
-    Layout of operands is determined by the `channel_first` `BoolAttr`:
+    Domain: N, S, C, KS
 
-    `channel_first == true`:
-      Input: `NCHW`
-      Kernel: `CHW`
-      Output: `NCHW`
+    Layouts:
+      `channel_first == true`:
+        Input: `NCS`
+        Kernel: `CS`
+        Output: `NCS`
 
-    `channel_first == false`:
-      Input: `NHWC`
-      Kernel: `HWC`
-      Output: `NHWC`
+      `channel_first == false`:
+        Input: `NSC`
+        Kernel: `SC`
+        Output: `NSC`
 
   }];
 
@@ -469,10 +419,8 @@ def DepthwiseConv2DOp : LinalgStructuredBase_Op<"depthwise_conv_2d",
       Variadic<TensorOrMemref>:$inputs,
       Variadic<TensorOrMemref>:$inits,
       DefaultValuedAttr<BoolAttr, "true">:$channel_first,
-      DefaultValuedAttr<RankedI64ElementsAttr<[2]>,
-        "{ static_cast<int64_t>(1) }">:$strides,
-      DefaultValuedAttr<RankedI64ElementsAttr<[2]>,
-        "{ static_cast<int64_t>(1) }">:$dilations
+      OptionalAttr<I64ElementsAttr>:$strides,
+      OptionalAttr<I64ElementsAttr>:$dilations
     );
     let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
     let regions = (region AnyRegion:$region);
@@ -496,59 +444,6 @@ def DepthwiseConv2DOp : LinalgStructuredBase_Op<"depthwise_conv_2d",
     }];
 }
 
-def DepthwiseConv3DOp : LinalgStructuredBase_Op<"depthwise_conv_3d",
-  [AttrSizedOperandSegments, LinalgDepthwiseConvolutionOpInterface]> {
-    
-  let summary = [{
-    Performs 3-D depthwise convolution with switchable channel position; either first or last.
-  }];
-  let description = [{
-    Domain: N, OD, OH, OW, C, KD, KH, KW
-
-    Layout of operands is determined by the `channel_first` `BoolAttr`:
-
-    `channel_first == true`:
-      Input: `NCDHW`
-      Kernel: `CDHW`
-      Output: `NCDHW`
-
-    `channel_first == false`:
-      Input: `NDHWC`
-      Kernel: `DHWC`
-      Output: `NDHWC`
-
-  }];
-
-    let arguments = (ins
-      Variadic<TensorOrMemref>:$inputs,
-      Variadic<TensorOrMemref>:$inits,
-      DefaultValuedAttr<BoolAttr, "true">:$channel_first,
-      DefaultValuedAttr<RankedI64ElementsAttr<[3]>,
-        "{ static_cast<int64_t>(1) }">:$strides,
-      DefaultValuedAttr<RankedI64ElementsAttr<[3]>,
-        "{ static_cast<int64_t>(1) }">:$dilations
-    );
-    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
-    let regions = (region AnyRegion:$region);
-
-    // TODO: Figure out how to move this to the interface
-    let extraClassDeclaration = structuredOpsBaseDecls # [{
-      void print(::mlir::OpAsmPrinter &printer) {
-        return detail::depthwise_convolution_impl::print(*this, printer);
-      }
-      static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
-                                        ::mlir::OperationState &result) {
-        return detail::depthwise_convolution_impl::parse(parser, result);
-      }
-      static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
-                                mlir::ArrayRef<mlir::NamedAttribute>)>
-      getRegionBuilder() {
-        return detail::depthwise_convolution_impl::regionBuilder;
-      }
-      // Implement functions necessary for DestinationStyleOpInterface.
-      MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
-    }];
-}
 
 //===----------------------------------------------------------------------===//
 // Transpose op.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 6fa9fd92304d30..145516d8ab0922 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -8,13 +8,11 @@
 
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 
-#include "mlir/AsmParser/AsmParser.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineExprVisitor.h"
@@ -643,8 +641,7 @@ enum class MatchConvolutionResult {
 DenseIntElementsAttr
 mlir::linalg::detail::depthwise_convolution_impl::getStridesAttr(
     DepthwiseConvolutionOpInterface op) {
-  auto maybeStridesAttr = op.getStridesAttr();
-  maybeStridesAttr.dump();
+  auto maybeStridesAttr = op->getAttrOfType<DenseIntElementsAttr>("strides");
   if (!maybeStridesAttr) {
     OpBuilder builder(op.getContext());
     int64_t numSpatialDims =
@@ -654,18 +651,24 @@ mlir::linalg::detail::depthwise_convolution_impl::getStridesAttr(
     SmallVector<int64_t> strides(numSpatialDims, 1);
     return DenseIntElementsAttr::get(type, strides);
   }
-  return op.getStridesAttr();
+  return maybeStridesAttr;
 }
 
 DenseIntElementsAttr
 mlir::linalg::detail::depthwise_convolution_impl::getDilationsAttr(
     DepthwiseConvolutionOpInterface op) {
-  return op.getDilationsAttr();
-}
-
-BoolAttr mlir::linalg::detail::depthwise_convolution_impl::getChannelFirstAttr(
-    DepthwiseConvolutionOpInterface op) {
-  return op.getChannelFirstAttr();
+  auto maybeDilationsAttr =
+      op->getAttrOfType<DenseIntElementsAttr>("dilations");
+  if (!maybeDilationsAttr) {
+    OpBuilder builder(op.getContext());
+    int64_t numSpatialDims =
+        op.image().getType().cast<ShapedType>().getRank() - 2;
+    auto type = RankedTensorType::get({static_cast<int64_t>(numSpatialDims)},
+                                      builder.getI64Type());
+    SmallVector<int64_t> strides(numSpatialDims, 1);
+    return DenseIntElementsAttr::get(type, strides);
+  }
+  return maybeDilationsAttr;
 }
 
 ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIteratorTypes(
@@ -701,8 +704,7 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
       llvm::map_range(llvm::seq<int64_t>(numSpatial + 2, 2 * (numSpatial + 1)),
                       [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
   // Temp subsitute for channel position attr
-  int64_t channelPos =
-      (op.getChannelFirstAttr().getValue()) ? 1 : numSpatial + 1;
+  int64_t channelPos = (op.getChannelFirst()) ? 1 : numSpatial + 1;
 
   // Initialze operand accesses in nw order and insert c according to channel
   // position
@@ -727,15 +729,19 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
   return cached;
 }
 
-LogicalResult mlir::linalg::detail::verifyDepthwiseConvolutionInterface(Operation *op) {
+LogicalResult
+mlir::linalg::detail::verifyDepthwiseConvolutionInterface(Operation *op) {
   if (failed(verifyConvolutionInterface(op)))
     return failure();
-  if (DepthwiseConvolutionOpInterface conv = dyn_cast<DepthwiseConvolutionOpInterface>(op)) {
+  if (DepthwiseConvolutionOpInterface conv =
+          dyn_cast<DepthwiseConvolutionOpInterface>(op)) {
     const auto imageRank = conv.image().getType().cast<ShapedType>().getRank();
-    const auto kernelRank = conv.filter().getType().cast<ShapedType>().getRank();
+    const auto kernelRank =
+        conv.filter().getType().cast<ShapedType>().getRank();
     const auto initRank = conv.init().getType().cast<ShapedType>().getRank();
     if (imageRank != initRank || imageRank != kernelRank + 1)
-      return op->emitError("Rank relationship must be `in_rank == out_rank == kernel_rank + 1`");
+      return op->emitError(
+          "Rank relationship must be `in_rank == out_rank == kernel_rank + 1`");
     return success();
   }
   return failure();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 23c7c4730a16aa..1a2935520604b0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1718,19 +1718,7 @@ void mlir::linalg::detail::depthwise_convolution_impl::regionBuilder(
 }
 
 // TODO: Figure out how to move this to interface
-void DepthwiseConv1DOp::getEffects(
-    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
-        &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
-                        getDpsInits());
-}
-void DepthwiseConv2DOp::getEffects(
-    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
-        &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
-                        getDpsInits());
-}
-void DepthwiseConv3DOp::getEffects(
+void DepthwiseConvNDOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
   getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 40d5078fc1d5fc..04e6c27f19dd81 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -844,7 +844,7 @@ func.func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32
 
 
 func.func @gen_depthwise_channel_first_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
-    linalg.depthwise_conv_1d {channel_first = true} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+    linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
     return
 }
 
@@ -874,7 +874,7 @@ func.func @gen_depthwise_channel_first_memref(%arg0: memref<?x?x?xf32>, %arg1: m
 
 
 func.func @gen_depthwise_channel_last_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
-    linalg.depthwise_conv_1d {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+    linalg.depthwise_conv_nd {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
     return
 }
 
@@ -903,7 +903,7 @@ func.func @gen_depthwise_channel_last_memref(%arg0: memref<?x?x?xf32>, %arg1: me
 //         COMMON:         store %[[res]], %[[arg2]][%[[n]], %[[ow]], %[[c]]] : memref<?x?x?xf32>
 
 func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
-    linalg.depthwise_conv_2d {channel_first = true} ins(%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?x?xf32>)
+    linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?x?xf32>)
     return
 }
 
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 26df1ea316e532..ed8ad2bf9e635f 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1,9 +1,9 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
-// CHECK-LABEL: func @gen_depthwise_channel_first_memref
-func.func @gen_depthwise_channel_first_memref(%arg0: memref<64x16x10xf32>, %arg1: memref<16x3xf32>, %arg2: memref<64x16x8xf32>) {
-  // CHECK: depthwise_conv_1d {{.*}}channel_first = true
-    linalg.depthwise_conv_1d {channel_first = true} ins(%arg0, %arg1: memref<64x16x10xf32>, memref<16x3xf32>) outs(%arg2: memref<64x16x8xf32>)
+// CHECK-LABEL: func @gen_depthwise_1D_channel_first_memref
+func.func @gen_depthwise_1D_channel_first_memref(%arg0: memref<64x16x10xf32>, %arg1: memref<16x3xf32>, %arg2: memref<64x16x8xf32>) {
+  // CHECK: depthwise_conv_nd {{.*}}channel_first = true
+    linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x16x10xf32>, memref<16x3xf32>) outs(%arg2: memref<64x16x8xf32>)
     return
 }
 
@@ -11,17 +11,17 @@ func.func @gen_depthwise_channel_first_memref(%arg0: memref<64x16x10xf32>, %arg1
 
 // CHECK-LABEL: func @gen_depthwise_2D_channel_first_memref
 func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<64x16x10x10xf32>, %arg1: memref<16x3x3xf32>, %arg2: memref<64x16x8x8xf32>) {
-  // CHECK: depthwise_conv_2d {{.*}}channel_first = true
-    linalg.depthwise_conv_2d {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10xf32>, memref<16x3x3xf32>) outs(%arg2: memref<64x16x8x8xf32>)
+  // CHECK: depthwise_conv_nd {{.*}}channel_first = true
+    linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10xf32>, memref<16x3x3xf32>) outs(%arg2: memref<64x16x8x8xf32>)
     return
 }
 
 // -----
 
 // CHECK-LABEL: func @gen_depthwise_3D_channel_first_memref
-func.func @gen_depthwise_3D_channel_first_memref(%arg0: memref<64x16x10x10xf32>, %arg1: memref<16x3x3xf32>, %arg2: memref<64x16x8x8xf32>) {
-  // CHECK: depthwise_conv_3d {{.*}}channel_first = true
-    linalg.depthwise_conv_3d {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10xf32>, memref<16x3x3xf32>) outs(%arg2: memref<64x16x8x8xf32>)
+func.func @gen_depthwise_3D_channel_first_memref(%arg0: memref<64x16x10x10x10xf32>, %arg1: memref<16x3x3x3xf32>, %arg2: memref<64x16x8x8x8xf32>) {
+  // CHECK: depthwise_conv_nd {{.*}}channel_first = true
+    linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10x10xf32>, memref<16x3x3x3xf32>) outs(%arg2: memref<64x16x8x8x8xf32>)
     return
 }
 
@@ -29,8 +29,8 @@ func.func @gen_depthwise_3D_channel_first_memref(%arg0: memref<64x16x10x10xf32>,
 
 // CHECK-LABEL: func @gen_depthwise_channel_last_memref
 func.func @gen_depthwise_channel_last_memref(%arg0: memref<64x26x16xf32>, %arg1: memref<3x16xf32>, %arg2: memref<64x8x16xf32>) {
-  // CHECK: depthwise_conv_1d {{.*}}channel_first = false 
-    linalg.depthwise_conv_1d {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<64x26x16xf32>, memref<3x16xf32>) outs(%arg2: memref<64x8x16xf32>)
+  // CHECK: depthwise_conv_nd {{.*}}channel_first = false 
+    linalg.depthwise_conv_nd {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<64x26x16xf32>, memref<3x16xf32>) outs(%arg2: memref<64x8x16xf32>)
     return
 }
 
@@ -38,8 +38,8 @@ func.func @gen_depthwise_channel_last_memref(%arg0: memref<64x26x16xf32>, %arg1:
 
 // CHECK-LABEL: func @gen_depthwise_channel_first_tensor
 func.func @gen_depthwise_channel_first_tensor(%arg0: tensor<64x16x10xf32>, %arg1: tensor<16x3xf32>, %arg2: tensor<64x16x8xf32>) -> tensor<64x16x8xf32> {
-  // CHECK: depthwise_conv_1d {{.*}}channel_first = true
-    %0 = linalg.depthwise_conv_1d {channel_first = true} ins(%arg0, %arg1: tensor<64x16x10xf32>, tensor<16x3xf32>) outs(%arg2: tensor<64x16x8xf32>) -> tensor<64x16x8xf32> 
+  // CHECK: depthwise_conv_nd {{.*}}channel_first = true
+    %0 = linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: tensor<64x16x10xf32>, tensor<16x3xf32>) outs(%arg2: tensor<64x16x8xf32>) -> tensor<64x16x8xf32> 
     return %0 : tensor<64x16x8xf32>
 }
 
@@ -47,8 +47,8 @@ func.func @gen_depthwise_channel_first_tensor(%arg0: tensor<64x16x10xf32>, %arg1
 
 // CHECK-LABEL: func @gen_depthwise_channel_last_tensor
 func.func @gen_depthwise_channel_last_tensor(%arg0: tensor<64x26x16xf32>, %arg1: tensor<3x16xf32>, %arg2: tensor<64x8x16xf32>) -> tensor<64x8x16xf32> {
-  // CHECK: depthwise_conv_1d {{.*}}channel_first = false 
-    %0 = linalg.depthwise_conv_1d {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: tensor<64x26x16xf32>, tensor<3x16xf32>) outs(%arg2: tensor<64x8x16xf32>) -> tensor<64x8x16xf32>
+  // CHECK: depthwise_conv_nd {{.*}}channel_first = false 
+    %0 = linalg.depthwise_conv_nd {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: tensor<64x26x16xf32>, tensor<3x16xf32>) outs(%arg2: tensor<64x8x16xf32>) -> tensor<64x8x16xf32>
     return %0 : tensor<64x8x16xf32>
 }
 

>From e02cb893f3b51ec606ed1a797337e5d623483404 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 9 Dec 2023 16:18:41 -0600
Subject: [PATCH 13/22] Fix a couple mistakes

---
 mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h  | 1 -
 mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 4 +++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 23240052607973..f3ae787dccea69 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -191,7 +191,6 @@ LogicalResult verifyConvolutionInterface(Operation *op);
 /// Verify that `op` conforms to the DepthwiseConvolutionOpInterface.
 LogicalResult verifyDepthwiseConvolutionInterface(Operation *op);
 
-LogicalResult verifyConvolutionInterface(Operation *op);
 /// Verify that `op` conforms to the FillOpInterface.
 LogicalResult verifyFillInterface(Operation *op);
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index d524fefbd43d24..c5b40ba6826e44 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -192,7 +192,9 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return $_op.getOperation()->getResult(2);
+        if ($_op.getOperation()->getResults().empty())
+          return nullptr;
+        return $_op.getOperation()->getResult(0);
       }]
     >,
   ];

>From d715ace643ec9df565ad7d71730b5a23c4332c90 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 9 Dec 2023 17:42:34 -0600
Subject: [PATCH 14/22] Fix print function

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1a2935520604b0..908a9817251694 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1686,14 +1686,16 @@ void mlir::linalg::detail::depthwise_convolution_impl::getEffects(
 
 ParseResult mlir::linalg::detail::depthwise_convolution_impl::parse(
     OpAsmParser &parser, OperationState &result) {
-  return parseNamedStructuredOp(
+  return ::parseNamedStructuredOp(
       parser, result, 3,
       mlir::linalg::detail::depthwise_convolution_impl::regionBuilder);
 }
 
 void mlir::linalg::detail::depthwise_convolution_impl::print(
     DepthwiseConvolutionOpInterface op, OpAsmPrinter &p) {
-  printNamedStructuredOp(p, op.getOperation(), op.image(), op.filter());
+  printNamedStructuredOp(p, op.getOperation(),
+                         ValueRange{op.image(), op.filter()},
+                         ValueRange{op.init()});
 }
 
 // Build {mul, add} region for convolution
@@ -1721,6 +1723,8 @@ void mlir::linalg::detail::depthwise_convolution_impl::regionBuilder(
 void DepthwiseConvNDOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
+  if (hasTensorSemantics())
+    return;
   getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
                         getDpsInits());
 }

>From 1e9822eb3f3eef1bea1c7455893f94b726db46d6 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 9 Dec 2023 18:25:05 -0600
Subject: [PATCH 15/22] Add bufferization test

---
 mlir/test/Dialect/Linalg/bufferize.mlir | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 29f27e6838e661..c4625841baa5c3 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -217,3 +217,21 @@ func.func public @main(%arg0: tensor<2x3xi1>) -> tensor<6xi64> {
   %3 = func.call @csum(%2) : (tensor<6xi64>) -> tensor<6xi64>
   return %3 : tensor<6xi64>
 }
+
+// CHECK-LABEL:   func @gen_depthwise_3D_channel_first_tensor(
+// CHECK-SAME:                                   %[[ARG0_TENSOR:.*]]: tensor<64x16x26x26x26xf32>,
+// CHECK-SAME:                                   %[[ARG1_TENSOR:.*]]: tensor<16x3x3x3xf32>,
+// CHECK-SAME:                                   %[[ARG2_TENSOR:.*]]: tensor<64x16x8x8x8xf32>) -> tensor<64x16x8x8x8xf32> {
+// CHECK-DAG:       %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0_TENSOR]] : memref<64x16x26x26x26xf32>
+// CHECK-DAG:       %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1_TENSOR]] : memref<16x3x3x3xf32>
+// CHECK-DAG:       %[[ARG2_MEMREF:.*]] = bufferization.to_memref %[[ARG2_TENSOR]] : memref<64x16x8x8x8xf32>
+// CHECK-DAG:       %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<64x16x8x8x8xf32>
+// CHECK:           memref.copy %[[ARG2_MEMREF]], %[[INIT_BUFFER]] : memref<64x16x8x8x8xf32> to memref<64x16x8x8x8xf32>
+// CHECK:           linalg.depthwise_conv_nd
+// CHECK-SAME:      {channel_first = true, dilations = dense<2> : tensor<3xi64>, strides = dense<3> : tensor<3xi64>}
+// CHECK-SAME:      ins(%[[ARG0_MEMREF]], %[[ARG1_MEMREF]] : memref<64x16x26x26x26xf32>, memref<16x3x3x3xf32>)
+// CHECK-SAME:      outs(%[[INIT_BUFFER]] : memref<64x16x8x8x8xf32>)
+func.func @gen_depthwise_3D_channel_first_tensor(%arg0: tensor<64x16x26x26x26xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<64x16x8x8x8xf32>) -> tensor<64x16x8x8x8xf32> {
+    %0 = linalg.depthwise_conv_nd {channel_first = true, strides = dense<3> : tensor<3xi64>, dilations = dense<2> : tensor<3xi64>} ins(%arg0, %arg1: tensor<64x16x26x26x26xf32>, tensor<16x3x3x3xf32>) outs(%arg2: tensor<64x16x8x8x8xf32>) -> tensor<64x16x8x8x8xf32>
+    return %0 : tensor<64x16x8x8x8xf32> 
+}
\ No newline at end of file

>From ce0f47ea6986cf5e222bb22b56fa8777fe48889e Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 9 Dec 2023 20:25:35 -0600
Subject: [PATCH 16/22] Add multiplier dimension

---
 .../Dialect/Linalg/IR/LinalgInterfaces.td     | 11 +--
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 26 +++---
 mlir/test/Dialect/Linalg/bufferize.mlir       | 22 ++---
 mlir/test/Dialect/Linalg/loops.mlir           | 90 ++++++++++---------
 mlir/test/Dialect/Linalg/named-ops.mlir       | 28 +++---
 5 files changed, 92 insertions(+), 85 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index c5b40ba6826e44..39e0d05768e782 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -903,14 +903,9 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
   let description = [{
     A depthwise convolution is defined in general terms:
     1. it is a convolution as defined by `ConvolutionOpInterface`.
-    2. `in_channels = K * out_channels` for some integer `K`.
-    3. `input_rank == output_rank == kernel_rank + 1` (including batch dim in input and output)
-    4. The indexing maps of the input have expressions that satisfy
-    ```
-      AffineExpr ::== AffineDimExpr | ConvolvedExpr
-      ConvolvedExpr ::== MulExpr (`+` MulExpr)+
-      MulExpr ::== AffineDimExpr (`*` (AffineConstantExpr | AffineSymbolExpr))?
-    ```
+    2. `out_channels = K * in_channels` for some integer `K`, the depth multiplier.
+    3. In the filter and the output, the dimension immediately following the channel dim has size `K`.
+    4. `input_rank == kernel_rank == output_rank - 1` (including batch dim in input and output)
   }];
   let cppNamespace = "::mlir::linalg";
   let verify = [{ return detail::verifyDepthwiseConvolutionInterface($_op); }];
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 145516d8ab0922..33f4cd49eb530a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -676,7 +676,7 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIteratorTypes(
   int64_t numSpatialDims =
       op.image().getType().cast<ShapedType>().getRank() - 2;
   SmallVector<Attribute> iteratorTypes(
-      2 + numSpatialDims, IteratorTypeAttr::get(op.getContext(), par));
+      3 + numSpatialDims, IteratorTypeAttr::get(op.getContext(), par));
   SmallVector<Attribute> reductions(
       numSpatialDims, IteratorTypeAttr::get(op.getContext(), red));
   iteratorTypes.insert(iteratorTypes.end(), reductions.begin(),
@@ -694,14 +694,15 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
 
   MLIRContext *ctx = op.getContext();
   auto numSpatial = op.image().getType().cast<ShapedType>().getRank() - 2;
-  // Domain: (n, w, c, kw)
+  // Domain: (n, w, c, m, kw)
   AffineExpr n = getAffineDimExpr(0, ctx);
   SmallVector<AffineExpr> s(
       llvm::map_range(llvm::seq<int64_t>(1, numSpatial + 1),
                       [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
   AffineExpr c = getAffineDimExpr(numSpatial + 1, ctx);
+  AffineExpr m = getAffineDimExpr(numSpatial + 2, ctx);
   SmallVector<AffineExpr> ks(
-      llvm::map_range(llvm::seq<int64_t>(numSpatial + 2, 2 * (numSpatial + 1)),
+      llvm::map_range(llvm::seq<int64_t>(numSpatial + 3, 2 * (numSpatial + 1) + 1),
                       [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
   // Temp subsitute for channel position attr
   int64_t channelPos = (op.getChannelFirst()) ? 1 : numSpatial + 1;
@@ -709,6 +710,7 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
   // Initialze operand accesses in nw order and insert c according to channel
   // position
   SmallVector<AffineExpr> inExprs = {n}, outExprs = {n};
+  SmallVector<AffineExpr> cm = {c, m};
   for (const auto &[sp, ksp, st, di] :
        llvm::zip(s, ks, op.getStridesAttr().getValues<int64_t>(),
                  op.getDilationsAttr().getValues<int64_t>())) {
@@ -718,13 +720,13 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
   SmallVector<AffineExpr> kExprs(ks);
   inExprs.insert(inExprs.begin() + channelPos, c);
   kExprs.insert(
-      channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, c);
-  outExprs.insert(outExprs.begin() + channelPos, c);
+      channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, cm.begin(), cm.end());
+  outExprs.insert(outExprs.begin() + channelPos, cm.begin(), cm.end());
 
   cached = Builder(ctx).getAffineMapArrayAttr(
-      {AffineMap::get(2 + 2 * numSpatial, 0, inExprs, ctx),
-       AffineMap::get(2 + 2 * numSpatial, 0, kExprs, ctx),
-       AffineMap::get(2 + 2 * numSpatial, 0, outExprs, ctx)});
+      {AffineMap::get(3 + 2 * numSpatial, 0, inExprs, ctx),
+       AffineMap::get(3 + 2 * numSpatial, 0, kExprs, ctx),
+       AffineMap::get(3 + 2 * numSpatial, 0, outExprs, ctx)});
   op->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
   return cached;
 }
@@ -735,11 +737,13 @@ mlir::linalg::detail::verifyDepthwiseConvolutionInterface(Operation *op) {
     return failure();
   if (DepthwiseConvolutionOpInterface conv =
           dyn_cast<DepthwiseConvolutionOpInterface>(op)) {
-    const auto imageRank = conv.image().getType().cast<ShapedType>().getRank();
+    const auto imageType = conv.image().getType().cast<ShapedType>();
+    const auto imageRank = imageType.getRank();
     const auto kernelRank =
         conv.filter().getType().cast<ShapedType>().getRank();
-    const auto initRank = conv.init().getType().cast<ShapedType>().getRank();
-    if (imageRank != initRank || imageRank != kernelRank + 1)
+    const auto initType = conv.init().getType().cast<ShapedType>();
+    const auto initRank = initType.getRank();
+    if (imageRank != kernelRank || imageRank != initRank - 1)
       return op->emitError(
           "Rank relationship must be `in_rank == out_rank == kernel_rank + 1`");
     return success();
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index c4625841baa5c3..3e04646a183907 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -220,18 +220,18 @@ func.func public @main(%arg0: tensor<2x3xi1>) -> tensor<6xi64> {
 
 // CHECK-LABEL:   func @gen_depthwise_3D_channel_first_tensor(
 // CHECK-SAME:                                   %[[ARG0_TENSOR:.*]]: tensor<64x16x26x26x26xf32>,
-// CHECK-SAME:                                   %[[ARG1_TENSOR:.*]]: tensor<16x3x3x3xf32>,
-// CHECK-SAME:                                   %[[ARG2_TENSOR:.*]]: tensor<64x16x8x8x8xf32>) -> tensor<64x16x8x8x8xf32> {
+// CHECK-SAME:                                   %[[ARG1_TENSOR:.*]]: tensor<16x2x3x3x3xf32>,
+// CHECK-SAME:                                   %[[ARG2_TENSOR:.*]]: tensor<64x16x2x8x8x8xf32>) -> tensor<64x16x2x8x8x8xf32> {
 // CHECK-DAG:       %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0_TENSOR]] : memref<64x16x26x26x26xf32>
-// CHECK-DAG:       %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1_TENSOR]] : memref<16x3x3x3xf32>
-// CHECK-DAG:       %[[ARG2_MEMREF:.*]] = bufferization.to_memref %[[ARG2_TENSOR]] : memref<64x16x8x8x8xf32>
-// CHECK-DAG:       %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<64x16x8x8x8xf32>
-// CHECK:           memref.copy %[[ARG2_MEMREF]], %[[INIT_BUFFER]] : memref<64x16x8x8x8xf32> to memref<64x16x8x8x8xf32>
+// CHECK-DAG:       %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1_TENSOR]] : memref<16x2x3x3x3xf32>
+// CHECK-DAG:       %[[ARG2_MEMREF:.*]] = bufferization.to_memref %[[ARG2_TENSOR]] : memref<64x16x2x8x8x8xf32>
+// CHECK-DAG:       %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<64x16x2x8x8x8xf32>
+// CHECK:           memref.copy %[[ARG2_MEMREF]], %[[INIT_BUFFER]] : memref<64x16x2x8x8x8xf32> to memref<64x16x2x8x8x8xf32>
 // CHECK:           linalg.depthwise_conv_nd
 // CHECK-SAME:      {channel_first = true, dilations = dense<2> : tensor<3xi64>, strides = dense<3> : tensor<3xi64>}
-// CHECK-SAME:      ins(%[[ARG0_MEMREF]], %[[ARG1_MEMREF]] : memref<64x16x26x26x26xf32>, memref<16x3x3x3xf32>)
-// CHECK-SAME:      outs(%[[INIT_BUFFER]] : memref<64x16x8x8x8xf32>)
-func.func @gen_depthwise_3D_channel_first_tensor(%arg0: tensor<64x16x26x26x26xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<64x16x8x8x8xf32>) -> tensor<64x16x8x8x8xf32> {
-    %0 = linalg.depthwise_conv_nd {channel_first = true, strides = dense<3> : tensor<3xi64>, dilations = dense<2> : tensor<3xi64>} ins(%arg0, %arg1: tensor<64x16x26x26x26xf32>, tensor<16x3x3x3xf32>) outs(%arg2: tensor<64x16x8x8x8xf32>) -> tensor<64x16x8x8x8xf32>
-    return %0 : tensor<64x16x8x8x8xf32> 
+// CHECK-SAME:      ins(%[[ARG0_MEMREF]], %[[ARG1_MEMREF]] : memref<64x16x26x26x26xf32>, memref<16x2x3x3x3xf32>)
+// CHECK-SAME:      outs(%[[INIT_BUFFER]] : memref<64x16x2x8x8x8xf32>)
+func.func @gen_depthwise_3D_channel_first_tensor(%arg0: tensor<64x16x26x26x26xf32>, %arg1: tensor<16x2x3x3x3xf32>, %arg2: tensor<64x16x2x8x8x8xf32>) -> tensor<64x16x2x8x8x8xf32> {
+    %0 = linalg.depthwise_conv_nd {channel_first = true, strides = dense<3> : tensor<3xi64>, dilations = dense<2> : tensor<3xi64>} ins(%arg0, %arg1: tensor<64x16x26x26x26xf32>, tensor<16x2x3x3x3xf32>) outs(%arg2: tensor<64x16x2x8x8x8xf32>) -> tensor<64x16x2x8x8x8xf32>
+    return %0 : tensor<64x16x2x8x8x8xf32> 
 }
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 04e6c27f19dd81..fb2e5c34331dc8 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -843,99 +843,107 @@ func.func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32
 //       CHECKPARALLEL:         store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
 
 
-func.func @gen_depthwise_channel_first_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
-    linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+func.func @gen_depthwise_channel_first_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
+    linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?x?xf32>)
     return
 }
 
 //    COMMON-LABEL: func @gen_depthwise_channel_first_memref
 //    COMMON-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
-//    COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32>
-//    COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//    COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//    COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
 //         COMMON: %[[c0:.*]] = arith.constant 0 : index
 //         COMMON: %[[c1:.*]] = arith.constant 1 : index
 //         COMMON: %[[c2:.*]] = arith.constant 2 : index
+//         COMMON: %[[c3:.*]] = arith.constant 3 : index
 //         COMMON: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?xf32>
 //         COMMON: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c1]] : memref<?x?x?xf32>
-//         COMMON: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c1]] : memref<?x?xf32>
-//         COMMON: %[[dim3:.*]] = memref.dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
+//         COMMON: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
+//         COMMON: %[[dim3:.*]] = memref.dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
+//         COMMON: %[[dim4:.*]] = memref.dim %[[arg2]], %[[c3]] : memref<?x?x?x?xf32>
 //          CHECK: scf.for %[[n:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
-//          CHECK:   scf.for %[[ow:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
+//          CHECK:   scf.for %[[ow:.*]] = %[[c0]] to %[[dim4]] step %[[c1]] {
 //          CHECK:     scf.for %[[c:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
-//  CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[ow:.*]], %[[c:.*]]) = (%[[c0]], %[[c0]], %[[c0]])  to (%[[dim0]], %[[dim3]], %[[dim1]])
-//         COMMON:       scf.for %[[kw:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
+//          CHECK:       scf.for %[[m:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
+//  CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[ow:.*]], %[[c:.*]], %[[m:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]])  to (%[[dim0]], %[[dim4]], %[[dim1]], %[[dim2]])
+//         COMMON:       scf.for %[[kw:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
 //         COMMON:         %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[ow]], %[[kw]])
 //         COMMON:         %[[vb:.*]] = memref.load %[[arg0]][%[[n]], %[[c]], %[[aff]]] : memref<?x?x?xf32>
-//         COMMON:         %[[va:.*]] = memref.load %[[arg1]][%[[c]], %[[kw]]] : memref<?x?xf32>
-//         COMMON:         %[[vc:.*]] = memref.load %[[arg2]][%[[n]], %[[c]], %[[ow]]] : memref<?x?x?xf32>
+//         COMMON:         %[[va:.*]] = memref.load %[[arg1]][%[[c]], %[[m]], %[[kw]]] : memref<?x?x?xf32>
+//         COMMON:         %[[vc:.*]] = memref.load %[[arg2]][%[[n]], %[[c]], %[[m]], %[[ow]]] : memref<?x?x?x?xf32>
 //         COMMON:         %[[inc:.*]] = arith.mulf %[[vb]], %[[va]] : f32
 //         COMMON:         %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
-//         COMMON:         store %[[res]], %[[arg2]][%[[n]], %[[c]], %[[ow]]] : memref<?x?x?xf32>
+//         COMMON:         store %[[res]], %[[arg2]][%[[n]], %[[c]], %[[m]], %[[ow]]] : memref<?x?x?x?xf32>
 
 
-func.func @gen_depthwise_channel_last_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
-    linalg.depthwise_conv_nd {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+func.func @gen_depthwise_channel_last_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
+    linalg.depthwise_conv_nd {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?x?xf32>)
     return
 }
 
 //    COMMON-LABEL: func @gen_depthwise_channel_last_memref
 //    COMMON-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
-//    COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32>
-//    COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//    COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//    COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
 //     COMMON-DAG: %[[c0:.*]] = arith.constant 0 : index
 //     COMMON-DAG: %[[c1:.*]] = arith.constant 1 : index
 //     COMMON-DAG: %[[c2:.*]] = arith.constant 2 : index
-//     COMMON-DAG: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?xf32>
-//     COMMON-DAG: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<?x?x?xf32>
-//     COMMON-DAG: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c0]] : memref<?x?xf32>
-//     COMMON-DAG: %[[dim3:.*]] = memref.dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
+//         COMMON: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?xf32>
+//         COMMON: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<?x?x?xf32>
+//         COMMON: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c0]] : memref<?x?x?xf32>
+//         COMMON: %[[dim3:.*]] = memref.dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
+//         COMMON: %[[dim4:.*]] = memref.dim %[[arg2]], %[[c1]] : memref<?x?x?x?xf32>
 //          CHECK: scf.for %[[n:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
-//          CHECK:   scf.for %[[ow:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
+//          CHECK:   scf.for %[[ow:.*]] = %[[c0]] to %[[dim4]] step %[[c1]] {
 //          CHECK:     scf.for %[[c:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
-//  CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[ow:.*]], %[[c:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim3]], %[[dim1]])
+//          CHECK:       scf.for %[[m:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
+//  CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[ow:.*]], %[[c:.*]], %[[m:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim4]], %[[dim1]], %[[dim3]])
 //         COMMON:       scf.for %[[kw:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
 //         COMMON:         %[[aff:.*]] = affine.apply #[[$stride3Dilation2]](%[[ow]], %[[kw]])
 //         COMMON:         %[[vb:.*]] = memref.load %[[arg0]][%[[n]], %[[aff]], %[[c]]] : memref<?x?x?xf32>
-//         COMMON:         %[[va:.*]] = memref.load %[[arg1]][%[[kw]], %[[c]]] : memref<?x?xf32>
-//         COMMON:         %[[vc:.*]] = memref.load %[[arg2]][%[[n]], %[[ow]], %[[c]]] : memref<?x?x?xf32>
+//         COMMON:         %[[va:.*]] = memref.load %[[arg1]][%[[kw]], %[[c]], %[[m]]] : memref<?x?x?xf32>
+//         COMMON:         %[[vc:.*]] = memref.load %[[arg2]][%[[n]], %[[ow]], %[[c]], %[[m]]] : memref<?x?x?x?xf32>
 //         COMMON:         %[[inc:.*]] = arith.mulf %[[vb]], %[[va]] : f32
 //         COMMON:         %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
-//         COMMON:         store %[[res]], %[[arg2]][%[[n]], %[[ow]], %[[c]]] : memref<?x?x?xf32>
+//         COMMON:         store %[[res]], %[[arg2]][%[[n]], %[[ow]], %[[c]], %[[m]]] : memref<?x?x?x?xf32>
 
-func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
-    linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?x?xf32>)
+func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?x?xf32>) {
+    linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>) outs(%arg2: memref<?x?x?x?x?xf32>)
     return
 }
 
 //    COMMON-LABEL: func @gen_depthwise_2D_channel_first_memref
 //    COMMON-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
-//    COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
-//    COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+//    COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+//    COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?x?xf32>
 //     COMMON-DAG: %[[c0:.*]] = arith.constant 0 : index
 //     COMMON-DAG: %[[c1:.*]] = arith.constant 1 : index
 //     COMMON-DAG: %[[c2:.*]] = arith.constant 2 : index
 //     COMMON-DAG: %[[c3:.*]] = arith.constant 3 : index
+//     COMMON-DAG: %[[c4:.*]] = arith.constant 4 : index
 //     COMMON-DAG: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?x?xf32>
 //     COMMON-DAG: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c1]] : memref<?x?x?x?xf32>
-//     COMMON-DAG: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
-//     COMMON-DAG: %[[dim3:.*]] = memref.dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
-//     COMMON-DAG: %[[dim4:.*]] = memref.dim %[[arg2]], %[[c2]] : memref<?x?x?x?xf32>
-//     COMMON-DAG: %[[dim5:.*]] = memref.dim %[[arg2]], %[[c3]] : memref<?x?x?x?xf32>
+//     COMMON-DAG: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c1]] : memref<?x?x?x?xf32>
+//     COMMON-DAG: %[[dim3:.*]] = memref.dim %[[arg1]], %[[c2]] : memref<?x?x?x?xf32>
+//     COMMON-DAG: %[[dim4:.*]] = memref.dim %[[arg1]], %[[c3]] : memref<?x?x?x?xf32>
+//     COMMON-DAG: %[[dim5:.*]] = memref.dim %[[arg2]], %[[c3]] : memref<?x?x?x?x?xf32>
+//     COMMON-DAG: %[[dim6:.*]] = memref.dim %[[arg2]], %[[c4]] : memref<?x?x?x?x?xf32>
 //          CHECK: scf.for %[[n:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
-//          CHECK:     scf.for %[[oh:.*]] = %[[c0]] to %[[dim4]] step %[[c1]] {
-//          CHECK:      scf.for %[[ow:.*]] = %[[c0]] to %[[dim5]] step %[[c1]] {
+//          CHECK:     scf.for %[[oh:.*]] = %[[c0]] to %[[dim5]] step %[[c1]] {
+//          CHECK:      scf.for %[[ow:.*]] = %[[c0]] to %[[dim6]] step %[[c1]] {
 //          CHECK:       scf.for %[[c:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
-//  CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[oh:.*]], %[[ow:.*]], %[[c:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]])  to (%[[dim0]], %[[dim4]], %[[dim5]], %[[dim1]])
-//         COMMON:       scf.for %[[kh:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
-//         COMMON:         scf.for %[[kw:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
+//          CHECK:         scf.for %[[m:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
+//  CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[oh:.*]], %[[ow:.*]], %[[c:.*]], %[[m:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]])  to (%[[dim0]], %[[dim5]], %[[dim6]], %[[dim1]], %[[dim2]])
+//         COMMON:       scf.for %[[kh:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
+//         COMMON:         scf.for %[[kw:.*]] = %[[c0]] to %[[dim4]] step %[[c1]] {
 //         COMMON:           %[[affh:.*]] = affine.apply #[[$stride1Dilation1]](%[[oh]], %[[kh]])
 //         COMMON:           %[[affw:.*]] = affine.apply #[[$stride1Dilation1]](%[[ow]], %[[kw]])
 //         COMMON:           %[[vb:.*]] = memref.load %[[arg0]][%[[n]], %[[c]], %[[affh]], %[[affw]]] : memref<?x?x?x?xf32>
-//         COMMON:           %[[va:.*]] = memref.load %[[arg1]][%[[c]], %[[kh]], %[[kw]]] : memref<?x?x?xf32>
-//         COMMON:           %[[vc:.*]] = memref.load %[[arg2]][%[[n]], %[[c]], %[[oh]], %[[ow]]] : memref<?x?x?x?xf32>
+//         COMMON:           %[[va:.*]] = memref.load %[[arg1]][%[[c]], %[[m]], %[[kh]], %[[kw]]] : memref<?x?x?x?xf32>
+//         COMMON:           %[[vc:.*]] = memref.load %[[arg2]][%[[n]], %[[c]], %[[m]], %[[oh]], %[[ow]]] : memref<?x?x?x?x?xf32>
 //         COMMON:           %[[inc:.*]] = arith.mulf %[[vb]], %[[va]] : f32
 //         COMMON:           %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
-//         COMMON:           store %[[res]], %[[arg2]][%[[n]], %[[c]], %[[oh]], %[[ow]]] : memref<?x?x?x?xf32>
+//         COMMON:           store %[[res]], %[[arg2]][%[[n]], %[[c]], %[[m]], %[[oh]], %[[ow]]] : memref<?x?x?x?x?xf32>
 
 // -----
 
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index ed8ad2bf9e635f..21191ab1cb6f3b 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1,55 +1,55 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
 // CHECK-LABEL: func @gen_depthwise_1D_channel_first_memref
-func.func @gen_depthwise_1D_channel_first_memref(%arg0: memref<64x16x10xf32>, %arg1: memref<16x3xf32>, %arg2: memref<64x16x8xf32>) {
+func.func @gen_depthwise_1D_channel_first_memref(%arg0: memref<64x16x10xf32>, %arg1: memref<16x1x3xf32>, %arg2: memref<64x16x1x8xf32>) {
   // CHECK: depthwise_conv_nd {{.*}}channel_first = true
-    linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x16x10xf32>, memref<16x3xf32>) outs(%arg2: memref<64x16x8xf32>)
+    linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x16x10xf32>, memref<16x1x3xf32>) outs(%arg2: memref<64x16x1x8xf32>)
     return
 }
 
 // -----
 
 // CHECK-LABEL: func @gen_depthwise_2D_channel_first_memref
-func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<64x16x10x10xf32>, %arg1: memref<16x3x3xf32>, %arg2: memref<64x16x8x8xf32>) {
+func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<64x16x10x10xf32>, %arg1: memref<16x1x3x3xf32>, %arg2: memref<64x16x1x8x8xf32>) {
   // CHECK: depthwise_conv_nd {{.*}}channel_first = true
-    linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10xf32>, memref<16x3x3xf32>) outs(%arg2: memref<64x16x8x8xf32>)
+    linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10xf32>, memref<16x1x3x3xf32>) outs(%arg2: memref<64x16x1x8x8xf32>)
     return
 }
 
 // -----
 
 // CHECK-LABEL: func @gen_depthwise_3D_channel_first_memref
-func.func @gen_depthwise_3D_channel_first_memref(%arg0: memref<64x16x10x10x10xf32>, %arg1: memref<16x3x3x3xf32>, %arg2: memref<64x16x8x8x8xf32>) {
+func.func @gen_depthwise_3D_channel_first_memref(%arg0: memref<64x16x10x10x10xf32>, %arg1: memref<16x1x3x3x3xf32>, %arg2: memref<64x16x1x8x8x8xf32>) {
   // CHECK: depthwise_conv_nd {{.*}}channel_first = true
-    linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10x10xf32>, memref<16x3x3x3xf32>) outs(%arg2: memref<64x16x8x8x8xf32>)
+    linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10x10xf32>, memref<16x1x3x3x3xf32>) outs(%arg2: memref<64x16x1x8x8x8xf32>)
     return
 }
 
 // -----
 
 // CHECK-LABEL: func @gen_depthwise_channel_last_memref
-func.func @gen_depthwise_channel_last_memref(%arg0: memref<64x26x16xf32>, %arg1: memref<3x16xf32>, %arg2: memref<64x8x16xf32>) {
+func.func @gen_depthwise_channel_last_memref(%arg0: memref<64x26x16xf32>, %arg1: memref<3x16x1xf32>, %arg2: memref<64x8x16x1xf32>) {
   // CHECK: depthwise_conv_nd {{.*}}channel_first = false 
-    linalg.depthwise_conv_nd {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<64x26x16xf32>, memref<3x16xf32>) outs(%arg2: memref<64x8x16xf32>)
+    linalg.depthwise_conv_nd {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<64x26x16xf32>, memref<3x16x1xf32>) outs(%arg2: memref<64x8x16x1xf32>)
     return
 }
 
 // -----
 
 // CHECK-LABEL: func @gen_depthwise_channel_first_tensor
-func.func @gen_depthwise_channel_first_tensor(%arg0: tensor<64x16x10xf32>, %arg1: tensor<16x3xf32>, %arg2: tensor<64x16x8xf32>) -> tensor<64x16x8xf32> {
+func.func @gen_depthwise_channel_first_tensor(%arg0: tensor<64x16x10xf32>, %arg1: tensor<16x1x3xf32>, %arg2: tensor<64x16x1x8xf32>) -> tensor<64x16x1x8xf32> {
   // CHECK: depthwise_conv_nd {{.*}}channel_first = true
-    %0 = linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: tensor<64x16x10xf32>, tensor<16x3xf32>) outs(%arg2: tensor<64x16x8xf32>) -> tensor<64x16x8xf32> 
-    return %0 : tensor<64x16x8xf32>
+    %0 = linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: tensor<64x16x10xf32>, tensor<16x1x3xf32>) outs(%arg2: tensor<64x16x1x8xf32>) -> tensor<64x16x1x8xf32> 
+    return %0 : tensor<64x16x1x8xf32>
 }
 
 // -----
 
 // CHECK-LABEL: func @gen_depthwise_channel_last_tensor
-func.func @gen_depthwise_channel_last_tensor(%arg0: tensor<64x26x16xf32>, %arg1: tensor<3x16xf32>, %arg2: tensor<64x8x16xf32>) -> tensor<64x8x16xf32> {
+func.func @gen_depthwise_channel_last_tensor(%arg0: tensor<64x26x16xf32>, %arg1: tensor<3x16x1xf32>, %arg2: tensor<64x8x16x1xf32>) -> tensor<64x8x16x1xf32> {
   // CHECK: depthwise_conv_nd {{.*}}channel_first = false 
-    %0 = linalg.depthwise_conv_nd {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: tensor<64x26x16xf32>, tensor<3x16xf32>) outs(%arg2: tensor<64x8x16xf32>) -> tensor<64x8x16xf32>
-    return %0 : tensor<64x8x16xf32>
+    %0 = linalg.depthwise_conv_nd {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: tensor<64x26x16xf32>, tensor<3x16x1xf32>) outs(%arg2: tensor<64x8x16x1xf32>) -> tensor<64x8x16x1xf32>
+    return %0 : tensor<64x8x16x1xf32>
 }
 
 // -----

>From 11c2a9404fdaf45b9378586fd9a1237f636bc67d Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 9 Dec 2023 20:51:44 -0600
Subject: [PATCH 17/22] Update docs (still needs more elaboration)

---
 .../include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td |  7 +++----
 .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td      | 10 +++++-----
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 39e0d05768e782..54d8cdf3e35d2a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -902,10 +902,9 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
   LinalgConvolutionOpInterface]> {
   let description = [{
     A depthwise convolution is defined in general terms:
-    1. it is a convolution as defined by `ConvolutionOpInterface`.
-    2. `in_channels = K * out_channels` for some integer `m`.
-    3. The dimension of the filter preceding the channel dim is equal to `K`, the depth multiplier
-    4. `input_rank == kernel_rank == output_rank + 1` (including batch dim in input and output)
+    1. It is a convolution as defined by `ConvolutionOpInterface`.
+    2. The channel and multiplier dimensions are consecutive and always in `CM` order.
+    3. `input_rank == kernel_rank == output_rank - 1` (the output carries the extra multiplier dim `M`; batch dim counted in input and output only)
   }];
   let cppNamespace = "::mlir::linalg";
   let verify = [{ return detail::verifyDepthwiseConvolutionInterface($_op); }];
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index c8e7746b17fac9..395e5219ada263 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -400,18 +400,18 @@ def DepthwiseConvNDOp : LinalgStructuredBase_Op<"depthwise_conv_nd",
     `bool` attritbute.  When placing the channel dim first or last, the batch dim is excluded.  In
     any case, the channel and spatial dims are in the same relative order for all operands.
     
-    Domain: N, S, C, KS
+    Domain: N, S, C, M, KS
 
     Layouts:
       `channel_first == true`:
         Input: `NCS`
-        Kernel: `CS`
-        Output: `NCS`
+        Kernel: `CMS`
+        Output: `NCMS`
 
       `channel_first == false`:
         Input: `NSC`
-        Kernel: `SC`
-        Output: `NSC`
+        Kernel: `SCM`
+        Output: `NSCM`
 
   }];
 

>From 221e3b3e724e271dbdae45b1b756638645d856a7 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 10 Dec 2023 09:58:01 -0600
Subject: [PATCH 18/22] Add builders

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 41 +++++++++++++++++++
 1 file changed, 41 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 395e5219ada263..09e2abaec8dc4e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -425,6 +425,47 @@ def DepthwiseConvNDOp : LinalgStructuredBase_Op<"depthwise_conv_nd",
     let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
     let regions = (region AnyRegion:$region);
 
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$inits, "bool":$channel_first,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, inits,
+          attributes, DepthwiseConvNDOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$inits, "bool":$channel_first,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+          inputs, inits, attributes, DepthwiseConvNDOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion();
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+      "ValueRange":$inits, "Attribute":$channel_first, "Attribute":$strides, "Attribute":$dilations,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("channel_first", channel_first);
+        $_state.addAttribute("strides", strides);
+        $_state.addAttribute("dilations", dilations);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, inits,
+          attributes, DepthwiseConvNDOp::getRegionBuilder());
+      }]>
+    ];
+
     // TODO: Figure out how to move this to the interface
     let extraClassDeclaration = structuredOpsBaseDecls # [{
       void print(::mlir::OpAsmPrinter &printer) {

>From cff07a1e747ddec13c0df6b3fef65fb218b938c5 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 10 Dec 2023 10:27:11 -0600
Subject: [PATCH 19/22] Move getStridesAttr and getDilationsAttr to
 ConvolutionInterface

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |  9 ++++++--
 .../Dialect/Linalg/IR/LinalgInterfaces.td     | 20 +++++++++---------
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 21 +++++++++----------
 3 files changed, 27 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index f3ae787dccea69..4cd086ce5de073 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -28,6 +28,7 @@ namespace mlir {
 namespace linalg {
 class IteratorTypeAttr;
 class LinalgOp;
+class ConvolutionOpInterface;
 class DepthwiseConvolutionOpInterface;
 
 namespace detail {
@@ -116,10 +117,14 @@ bool isaCopyOpInterface(LinalgOp linalgOp);
 
 namespace detail {
 
+// Common implementation for ConvolutionOpInterface
+namespace convolution_impl {
+DenseIntElementsAttr getStridesAttr(ConvolutionOpInterface op);
+DenseIntElementsAttr getDilationsAttr(ConvolutionOpInterface op);
+} // namespace convolution_impl
+
 // Common implementations for DepthwiseConvolutionOpInterface
 namespace depthwise_convolution_impl {
-DenseIntElementsAttr getStridesAttr(DepthwiseConvolutionOpInterface op);
-DenseIntElementsAttr getDilationsAttr(DepthwiseConvolutionOpInterface op);
 ArrayAttr getIndexingMaps(DepthwiseConvolutionOpInterface op);
 ArrayAttr getIteratorTypes(DepthwiseConvolutionOpInterface op);
 void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 54d8cdf3e35d2a..c2e16c7622e12f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -198,6 +198,16 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
       }]
     >,
   ];
+  let extraSharedClassDeclaration = [{
+    // Returns strides attribute.
+    ::mlir::DenseIntElementsAttr getStridesAttr() {
+        return detail::convolution_impl::getStridesAttr($_op);
+    }
+    // Returns dilations attribute.
+    ::mlir::DenseIntElementsAttr getDilationsAttr() {
+        return detail::convolution_impl::getDilationsAttr($_op);
+    }
+  }];
 }
 
 
@@ -933,16 +943,6 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
         return (*this)->getAttrOfType<BoolAttr>("channel_first").getValue();
     }
   }];
-  let extraSharedClassDeclaration = [{
-    // Returns strides attribute.
-    ::mlir::DenseIntElementsAttr getStridesAttr() {
-        return detail::depthwise_convolution_impl::getStridesAttr($_op);
-    }
-    // Returns dilations attribute.
-    ::mlir::DenseIntElementsAttr getDilationsAttr() {
-        return detail::depthwise_convolution_impl::getDilationsAttr($_op);
-    }
-  }];
 }
 
 #endif // LINALG_IR_LINALGINTERFACES
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 33f4cd49eb530a..e93598ecd98d99 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -638,9 +638,8 @@ enum class MatchConvolutionResult {
 };
 } // namespace mlir::linalg::detail
 
-DenseIntElementsAttr
-mlir::linalg::detail::depthwise_convolution_impl::getStridesAttr(
-    DepthwiseConvolutionOpInterface op) {
+DenseIntElementsAttr mlir::linalg::detail::convolution_impl::getStridesAttr(
+    ConvolutionOpInterface op) {
   auto maybeStridesAttr = op->getAttrOfType<DenseIntElementsAttr>("strides");
   if (!maybeStridesAttr) {
     OpBuilder builder(op.getContext());
@@ -654,9 +653,8 @@ mlir::linalg::detail::depthwise_convolution_impl::getStridesAttr(
   return maybeStridesAttr;
 }
 
-DenseIntElementsAttr
-mlir::linalg::detail::depthwise_convolution_impl::getDilationsAttr(
-    DepthwiseConvolutionOpInterface op) {
+DenseIntElementsAttr mlir::linalg::detail::convolution_impl::getDilationsAttr(
+    ConvolutionOpInterface op) {
   auto maybeDilationsAttr =
       op->getAttrOfType<DenseIntElementsAttr>("dilations");
   if (!maybeDilationsAttr) {
@@ -701,9 +699,9 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
                       [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
   AffineExpr c = getAffineDimExpr(numSpatial + 1, ctx);
   AffineExpr m = getAffineDimExpr(numSpatial + 2, ctx);
-  SmallVector<AffineExpr> ks(
-      llvm::map_range(llvm::seq<int64_t>(numSpatial + 3, 2 * (numSpatial + 1) + 1),
-                      [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+  SmallVector<AffineExpr> ks(llvm::map_range(
+      llvm::seq<int64_t>(numSpatial + 3, 2 * (numSpatial + 1) + 1),
+      [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
   // Temp subsitute for channel position attr
   int64_t channelPos = (op.getChannelFirst()) ? 1 : numSpatial + 1;
 
@@ -719,8 +717,9 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
   }
   SmallVector<AffineExpr> kExprs(ks);
   inExprs.insert(inExprs.begin() + channelPos, c);
-  kExprs.insert(
-      channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, cm.begin(), cm.end());
+  kExprs.insert(channelPos == 0 ? kExprs.begin()
+                                : kExprs.begin() + channelPos - 1,
+                cm.begin(), cm.end());
   outExprs.insert(outExprs.begin() + channelPos, cm.begin(), cm.end());
 
   cached = Builder(ctx).getAffineMapArrayAttr(

>From dbe552e86d37575517d2ab1970e9b6549531f8ab Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 10 Dec 2023 11:29:09 -0600
Subject: [PATCH 20/22] Add tiling regression test

---
 mlir/test/Dialect/Linalg/tile-conv.mlir | 42 ++++++++++++++++++++++++-
 1 file changed, 41 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir
index 4a940f12662e6c..c24a6d77242801 100644
--- a/mlir/test/Dialect/Linalg/tile-conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile-conv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -transform-interpreter -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter -canonicalize -split-input-file | FileCheck %s
 
 //  CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
 //  CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
@@ -41,3 +41,43 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:       linalg.conv_2d
 //  CHECK-SAME:         ins(%[[SVIN]], %[[SVKER]]
 //  CHECK-SAME:         outs(%[[SVOUT]]
+
+// -----
+
+func.func @depthwise_conv_1D(%input : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %output : memref<?x?x?x?xf32>) {
+  linalg.depthwise_conv_nd ins(%input, %filter : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%output : memref<?x?x?x?xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %conv = transform.structured.match ops{["linalg.depthwise_conv_nd"]} in %root : (!transform.any_op) -> !transform.any_op
+    %tiled, %loops:2 = transform.structured.tile_using_for %conv [2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+//       CHECK: func @depthwise_conv_1D
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32>
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[BATCH:.*]] = memref.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[CHANNELS:.*]] = memref.dim %[[ARG0]], %[[C1]]
+//   CHECK-DAG:   %[[MULTIPLIER:.*]] = memref.dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[KW:.*]] = memref.dim %[[ARG1]], %[[C2]]
+//   CHECK-DAG:   %[[W:.*]] = memref.dim %[[ARG2]], %[[C3]]
+//       CHECK:   scf.for %[[I:.*]] = %[[C0]] to %[[BATCH]] step %[[C2]]
+//       CHECK:     %[[T4:.*]] = affine.min #[[MAP0]](%[[I]])[%[[BATCH]]]
+//       CHECK:     scf.for %[[J:.*]] = %[[C0]] to %[[W]] step %[[C3]]
+//   CHECK-DAG:       %[[T5:.*]] = affine.min #[[MAP1]](%[[J]])[%[[W]]]
+//   CHECK-DAG:       %[[T6:.*]] = affine.apply #[[MAP2]](%[[T5]])[%[[KW]]]
+//   CHECK-DAG:       %[[SVIN:.*]] = memref.subview %[[ARG0]][%[[I]], 0, %[[J]]] [%[[T4]], %[[CHANNELS]], %[[T6]]]
+//   CHECK-DAG:       %[[SVKER:.*]] = memref.subview %[[ARG1]][0, 0, 0] [%[[CHANNELS]], %[[MULTIPLIER]], %[[KW]]]
+//   CHECK-DAG:       %[[SVOUT:.*]] = memref.subview %[[ARG2]][%[[I]], 0, 0, %[[J]]] [%[[T4]], %[[CHANNELS]], %[[MULTIPLIER]], %[[T5]]]
+//       CHECK:       linalg.depthwise_conv_nd {channel_first = true}
+//  CHECK-SAME:         ins(%[[SVIN]], %[[SVKER]]
+//  CHECK-SAME:         outs(%[[SVOUT]]

>From 1ccaf0d034d73145ddfa9e3265f856950628a7f7 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 10 Dec 2023 15:41:48 -0600
Subject: [PATCH 21/22] Add quantized depthwise op (still needs serious
 refactoring)

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |   4 +-
 .../Dialect/Linalg/IR/LinalgInterfaces.td     |   9 +-
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 100 ++++++++++++++++++
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    |  14 ++-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  54 +++++++++-
 mlir/test/Dialect/Linalg/named-ops.mlir       |  11 ++
 6 files changed, 181 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 4cd086ce5de073..2b73bd80d2340e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -129,11 +129,13 @@ ArrayAttr getIndexingMaps(DepthwiseConvolutionOpInterface op);
 ArrayAttr getIteratorTypes(DepthwiseConvolutionOpInterface op);
 void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                    ArrayRef<NamedAttribute> attrs);
+void quantizedRegionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                   ArrayRef<NamedAttribute> attrs);
 void getEffects(
     DepthwiseConvolutionOpInterface op,
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects);
-ParseResult parse(OpAsmParser &parser, OperationState &result);
+ParseResult parse(OpAsmParser &parser, OperationState &result, bool isQuantized = false);
 void print(DepthwiseConvolutionOpInterface op, OpAsmPrinter &p);
 } // namespace depthwise_convolution_impl
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index c2e16c7622e12f..7e55c2cc01c3b6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -182,7 +182,7 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return $_op.getOperation()->getOperand(2);
+        return $_op.getOperation()->getOperand($_op.getOperation()->getOperands().size() - 1);
       }]
     >,
     InterfaceMethod<
@@ -934,7 +934,12 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
       /*methodBody=*/[{}],
       /*defaultImplementation=*/[{
         return detail::depthwise_convolution_impl::getIteratorTypes($_op);
-      }]>
+      }]>,
+    InterfaceMethod<[{
+      Returns whether op is quantized or not.
+    }],
+    "bool", "isQuantized", (ins)
+    >
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 09e2abaec8dc4e..608df597b40107 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -482,9 +482,109 @@ def DepthwiseConvNDOp : LinalgStructuredBase_Op<"depthwise_conv_nd",
       }
       // Implement functions necessary for DestinationStyleOpInterface.
       MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
+      // DepthwiseConvolutionInterface implementations
+      bool isQuantized() {return false; }
     }];
 }
 
+def DepthwiseConvNDQOp : LinalgStructuredBase_Op<"depthwise_conv_nd_q",
+  [AttrSizedOperandSegments, LinalgDepthwiseConvolutionOpInterface]> {
+
+  let summary = [{
+    Performs quantized N-D depthwise convolution with switchable channel position; either first or last.
+  }];
+  let description = [{
+    Quantized variant of `depthwise_conv_nd` with any number of spatial dims, treated as contiguous;
+    `S` denotes all of them. The `bool` attribute `channel_first` selects the operand layouts (the
+    batch dim always stays first); channel and spatial dims keep the same relative order everywhere.
+    The inputs are the input, the kernel and two scalar zero points; every value is signed-cast to
+    the init element type and the body computes `init + (input - input_zp) * (kernel - kernel_zp)`.
+
+    Domain: N, S, C, M, KS
+
+    Layouts:
+      `channel_first == true`:
+        Input: `NCS`
+        Kernel: `CMS`
+        Output: `NCMS`
+
+      `channel_first == false`:
+        Input: `NSC`
+        Kernel: `SCM`
+        Output: `NSCM`
+  }];
+
+    let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<TensorOrMemref>:$inits,
+      DefaultValuedAttr<BoolAttr, "true">:$channel_first,
+      OptionalAttr<I64ElementsAttr>:$strides,
+      OptionalAttr<I64ElementsAttr>:$dilations
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$inits, "bool":$channel_first,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, inits,
+          attributes, DepthwiseConvNDQOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$inits, "bool":$channel_first,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+          inputs, inits, attributes, DepthwiseConvNDQOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion();
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+      "ValueRange":$inits, "Attribute":$channel_first, "Attribute":$strides, "Attribute":$dilations,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("channel_first", channel_first);
+        $_state.addAttribute("strides", strides);
+        $_state.addAttribute("dilations", dilations);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, inits,
+          attributes, DepthwiseConvNDQOp::getRegionBuilder());
+      }]>
+    ];
+
+    // TODO: Figure out how to move this to the interface
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      void print(::mlir::OpAsmPrinter &printer) {
+        return detail::depthwise_convolution_impl::print(*this, printer);
+      }
+      static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
+                                        ::mlir::OperationState &result) {
+        return detail::depthwise_convolution_impl::parse(parser, result, /*isQuantized=*/true);
+      }
+      static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+                                mlir::ArrayRef<mlir::NamedAttribute>)>
+      getRegionBuilder() {
+        return detail::depthwise_convolution_impl::quantizedRegionBuilder;
+      }
+      // Implement functions necessary for DestinationStyleOpInterface.
+      MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
+      // DepthwiseConvolutionInterface implementations
+      bool isQuantized() { return true; }
+    }];
+}
 
 //===----------------------------------------------------------------------===//
 // Transpose op.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index e93598ecd98d99..d50952f1169c83 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -722,10 +722,18 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
                 cm.begin(), cm.end());
   outExprs.insert(outExprs.begin() + channelPos, cm.begin(), cm.end());
 
-  cached = Builder(ctx).getAffineMapArrayAttr(
+  SmallVector<AffineMap> maps(
       {AffineMap::get(3 + 2 * numSpatial, 0, inExprs, ctx),
        AffineMap::get(3 + 2 * numSpatial, 0, kExprs, ctx),
        AffineMap::get(3 + 2 * numSpatial, 0, outExprs, ctx)});
+
+  if (op.isQuantized()) {
+    SmallVector<AffineMap> scalarMaps(
+        2, AffineMap::get(3 + 2 * numSpatial, 0, {}, ctx));
+    maps.insert(maps.end() - 1, scalarMaps.begin(), scalarMaps.end());
+  }
+
+  cached = Builder(ctx).getAffineMapArrayAttr(maps);
   op->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
   return cached;
 }
@@ -736,11 +744,19 @@ mlir::linalg::detail::verifyDepthwiseConvolutionInterface(Operation *op) {
     return failure();
   if (DepthwiseConvolutionOpInterface conv =
           dyn_cast<DepthwiseConvolutionOpInterface>(op)) {
-    const auto imageType = conv.image().getType().cast<ShapedType>();
+    // Inputs may be declared Variadic<AnyType> (e.g. scalar zero points on
+    // quantized ops), so diagnose non-shaped operands instead of crashing.
+    const auto imageType = dyn_cast<ShapedType>(conv.image().getType());
+    if (!imageType)
+      return op->emitError("expected image operand to be a shaped type");
     const auto imageRank = imageType.getRank();
-    const auto kernelRank =
-        conv.filter().getType().cast<ShapedType>().getRank();
-    const auto initType = conv.init().getType().cast<ShapedType>();
+    const auto filterType = dyn_cast<ShapedType>(conv.filter().getType());
+    if (!filterType)
+      return op->emitError("expected filter operand to be a shaped type");
+    const auto kernelRank = filterType.getRank();
+    const auto initType = dyn_cast<ShapedType>(conv.init().getType());
+    if (!initType)
+      return op->emitError("expected init operand to be a shaped type");
     const auto initRank = initType.getRank();
     if (imageRank != kernelRank || imageRank != initRank - 1)
       return op->emitError(
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 10745a93bd9c60..f9f9aae8c9f388 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1671,17 +1671,28 @@ void mlir::linalg::detail::depthwise_convolution_impl::getEffects(
 }
 
 ParseResult mlir::linalg::detail::depthwise_convolution_impl::parse(
-    OpAsmParser &parser, OperationState &result) {
-  return ::parseNamedStructuredOp(
+    OpAsmParser &parser, OperationState &result, bool isQuantized) {
+  // Quantized ops use 5 region args (input, filter, two zero points, init)
+  // instead of 3, together with the matching region builder.
+  if (isQuantized)
+    return parseNamedStructuredOp(parser, result, 5, quantizedRegionBuilder);
+
+  return parseNamedStructuredOp(
       parser, result, 3,
       mlir::linalg::detail::depthwise_convolution_impl::regionBuilder);
 }
 
 void mlir::linalg::detail::depthwise_convolution_impl::print(
     DepthwiseConvolutionOpInterface op, OpAsmPrinter &p) {
-  printNamedStructuredOp(p, op.getOperation(),
-                         ValueRange{op.image(), op.filter()},
-                         ValueRange{op.init()});
+  // Quantized variants also print their two scalar zero-point operands
+  // (operands 2 and 3) as inputs, after the image and filter.
+  SmallVector<Value, 4> inputs{op.image(), op.filter()};
+  if (op.isQuantized()) {
+    inputs.push_back(op->getOperand(2));
+    inputs.push_back(op->getOperand(3));
+  }
+  printNamedStructuredOp(p, op.getOperation(), inputs,
+                         ValueRange{op.init()});
 }
 
 // Build {mul, add} region for convolution
@@ -1705,6 +1716,31 @@ void mlir::linalg::detail::depthwise_convolution_impl::regionBuilder(
   helper.yieldOutputs(yields);
 }
 
+// Quantized depthwise body: acc + (in - in_zp) * (kernel - kernel_zp),
+// with every operand signed-cast to the accumulator type.
+void mlir::linalg::detail::depthwise_convolution_impl::quantizedRegionBuilder(
+    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 5 &&
+         "DepthwiseConvNDQOp regionBuilder expects 5 args");
+  RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
+  // Block args: input, kernel, input zp, kernel zp, accumulator (init).
+  Value acc = block.getArgument(4);
+  Type accType = acc.getType();
+  Value in = helper.buildTypeFn(TypeFn::cast_signed, accType,
+                                block.getArgument(0));
+  Value inZp = helper.buildTypeFn(TypeFn::cast_signed, accType,
+                                  block.getArgument(2));
+  Value inShifted = helper.buildBinaryFn(BinaryFn::sub, in, inZp);
+  Value ker = helper.buildTypeFn(TypeFn::cast_signed, accType,
+                                 block.getArgument(1));
+  Value kerZp = helper.buildTypeFn(TypeFn::cast_signed, accType,
+                                   block.getArgument(3));
+  Value kerShifted = helper.buildBinaryFn(BinaryFn::sub, ker, kerZp);
+  Value prod = helper.buildBinaryFn(BinaryFn::mul, inShifted, kerShifted);
+  Value sum = helper.buildBinaryFn(BinaryFn::add, acc, prod);
+  helper.yieldOutputs({sum});
+}
+
 // TODO: Figure out how to move this to interface
 void DepthwiseConvNDOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
@@ -1714,6 +1750,14 @@ void DepthwiseConvNDOp::getEffects(
   getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
                         getDpsInits());
 }
+void DepthwiseConvNDQOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  // Effects are only reported when the op does not have tensor semantics.
+  if (!hasTensorSemantics())
+    getGenericEffectsImpl(effects, getOperation()->getResults(),
+                          getDpsInputs(), getDpsInits());
+}
 
 //===----------------------------------------------------------------------===//
 // TransposeOp
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index afbcf30c87207c..c85d143725cf45 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -54,6 +54,17 @@ func.func @gen_depthwise_channel_last_tensor(%arg0: tensor<64x26x16xf32>, %arg1:
 
 // -----
 
+// CHECK-LABEL: func @gen_depthwise_1D_channel_first_quantized_memref
+func.func @gen_depthwise_1D_channel_first_quantized_memref(%input: memref<64x16x10xi8>, %filter: memref<16x1x3xi8>, %init: memref<64x16x1x8xi64>) {
+  // CHECK: depthwise_conv_nd_q {{.*}}channel_first = true
+  %input_zp = arith.constant 0 : i32
+  %filter_zp = arith.constant 2 : i32
+  linalg.depthwise_conv_nd_q {channel_first = true} ins(%input, %filter, %input_zp, %filter_zp : memref<64x16x10xi8>, memref<16x1x3xi8>, i32, i32) outs(%init : memref<64x16x1x8xi64>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @depthwise_conv_1d_nwc_wcm
 func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<1x12x8xf32>, %filter: tensor<3x8x8xf32>) -> tensor<1x10x8x8xf32> {
   %zero = arith.constant 0.000000e+00 : f32

>From d60d49aabf6b00274094bb3aa5de664c4ad110e8 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 10 Dec 2023 21:43:44 -0600
Subject: [PATCH 22/22] clang-format

---
 mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 2b73bd80d2340e..3104d0670cd5eb 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -130,12 +130,13 @@ ArrayAttr getIteratorTypes(DepthwiseConvolutionOpInterface op);
 void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                    ArrayRef<NamedAttribute> attrs);
 void quantizedRegionBuilder(ImplicitLocOpBuilder &b, Block &block,
-                   ArrayRef<NamedAttribute> attrs);
+                            ArrayRef<NamedAttribute> attrs);
 void getEffects(
     DepthwiseConvolutionOpInterface op,
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects);
-ParseResult parse(OpAsmParser &parser, OperationState &result, bool isQuantized = false);
+ParseResult parse(OpAsmParser &parser, OperationState &result,
+                  bool isQuantized = false);
 void print(DepthwiseConvolutionOpInterface op, OpAsmPrinter &p);
 } // namespace depthwise_convolution_impl
 



More information about the cfe-commits mailing list