[Mlir-commits] [mlir] [Linalg] Add *Conv2D* matchers (PR #168362)

Abhishek Varma llvmlistbot at llvm.org
Sat Nov 29 08:46:33 PST 2025


https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/168362

>From 968cb69151838a6cd97ec82dcecc61a11526cf57 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Fri, 28 Nov 2025 01:48:46 -0600
Subject: [PATCH] [Linalg] Add *Conv2D* matchers

-- This commit is the fourth in a series that adds matchers
   for linalg.conv/pool ops. Refer: #163724
-- This commit adds matchers for all variants of the Conv2D
   convolution ops.

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
 .../Dialect/Linalg/Transforms/Specialize.cpp  |  15 +
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 549 +++++++++++++++++-
 .../convolution/roundtrip-convolution.mlir    | 195 +++++++
 3 files changed, 748 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index c2485a08932dd..bbfbd2e9736a1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -279,6 +279,17 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
   CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
   CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
   CONV_OP_SPECIALIZER(linalg::Conv2DOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwFgchwOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcQOp);
   CONV_OP_SPECIALIZER(linalg::Conv3DOp);
   // -----------------------------
   // Depthwise Convolution ops.
@@ -287,6 +298,10 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcQOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmQOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
   // -----------------------------
   // Pooling ops.
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index e85a2ab26bd32..a59b2663f2998 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -240,8 +240,8 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
 //===----------------------------------------------------------------------===//
 
 /// Returns the BlockArgument that leads to `val`, if any. Traverses optional
-/// ext* ops.
-static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
+/// ext*/sitofp ops.
+static BlockArgument getBlockArgumentWithOptionalCastOps(Value val) {
   BlockArgument blockArg = dyn_cast<BlockArgument>(val);
   if ((blockArg))
     return blockArg;
@@ -249,18 +249,62 @@ static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
   Operation *defOp = val.getDefiningOp();
   if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
       !dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
-      !dyn_cast_if_present<arith::ExtUIOp>(defOp)) {
+      !dyn_cast_if_present<arith::ExtUIOp>(defOp) &&
+      !dyn_cast_if_present<arith::SIToFPOp>(defOp)) {
     return nullptr;
   }
   return dyn_cast<BlockArgument>(defOp->getOperand(0));
 }
 
+/// Utility function to match the zero point offset body of convolution ops.
+/// It takes as input the addition op and the multiplication op expected in
+/// every convolution op, and matches the following pattern for both operands
+/// of the multiplication op :-
+///     %a - %b
+///   where %a and %b can each have an optional cast (ext*/sitofp) operation.
+static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp,
+                                           Block *body) {
+  Operation *subOp1 = mulOp->getOperand(0).getDefiningOp();
+  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(subOp1))
+    return false;
+  Operation *subOp2 = mulOp->getOperand(1).getDefiningOp();
+  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(subOp2))
+    return false;
+  BlockArgument inputBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp1->getOperand(0));
+  BlockArgument inputScalarBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp1->getOperand(1));
+  BlockArgument filterBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp2->getOperand(0));
+  BlockArgument filterScalarBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp2->getOperand(1));
+  BlockArgument outBlockArg =
+      getBlockArgumentWithOptionalCastOps(addOp->getOperand(0));
+  if (!inputBlockArg || !inputScalarBlockArg || !filterBlockArg ||
+      !filterScalarBlockArg || !outBlockArg ||
+      inputBlockArg.getOwner() != body ||
+      inputScalarBlockArg.getOwner() != body ||
+      filterBlockArg.getOwner() != body ||
+      filterScalarBlockArg.getOwner() != body ||
+      outBlockArg.getOwner() != body || inputBlockArg.getArgNumber() != 0 ||
+      inputScalarBlockArg.getArgNumber() != 2 ||
+      filterBlockArg.getArgNumber() != 1 ||
+      filterScalarBlockArg.getArgNumber() != 3 ||
+      outBlockArg.getArgNumber() != 4)
+    return false;
+  return true;
+}
+
 /// Utility to match block body for convolution ops.
 /// The body is thus expected to yield :-
 ///     %out + (%lhs * %rhs)
 ///   where: %lhs, %rhs and %out are block arguments and
 ///          %lhs and %rhs can have optional upcast operation.
-static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) {
+/// NOTE: For zero point offset convolution ops, %lhs and %rhs are each :-
+///       %operand - %operand_scalar
+///          where both can have an optional cast (ext*/sitofp) operation.
+static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
+                                         bool zeroPointOffset = false) {
   Operation *addOp = yieldVal.getDefiningOp();
   if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp))
     return false;
@@ -269,12 +313,15 @@ static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) {
   if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
     return false;
 
+  if (zeroPointOffset) {
+    return bodyMatcherForZeroPointOffsets(addOp, mulOp, body);
+  }
   BlockArgument lhsBlockArg =
-      getBlockArgumentWithOptionalExtOps(mulOp->getOperand(0));
+      getBlockArgumentWithOptionalCastOps(mulOp->getOperand(0));
   BlockArgument rhsBlockArg =
-      getBlockArgumentWithOptionalExtOps(mulOp->getOperand(1));
+      getBlockArgumentWithOptionalCastOps(mulOp->getOperand(1));
   BlockArgument outBlockArg =
-      getBlockArgumentWithOptionalExtOps(addOp->getOperand(0));
+      getBlockArgumentWithOptionalCastOps(addOp->getOperand(0));
   if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
       lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
       outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
@@ -291,9 +338,9 @@ static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
     return false;
 
   BlockArgument lhsArg =
-      getBlockArgumentWithOptionalExtOps(defOp->getOperand(0));
+      getBlockArgumentWithOptionalCastOps(defOp->getOperand(0));
   BlockArgument rhsArg =
-      getBlockArgumentWithOptionalExtOps(defOp->getOperand(1));
+      getBlockArgumentWithOptionalCastOps(defOp->getOperand(1));
   if (!lhsArg || !rhsArg || lhsArg.getOwner() != body ||
       rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 ||
       rhsArg.getArgNumber() != 0)
@@ -488,14 +535,15 @@ class ConvMatcherBuilder {
   }
 
   /// Match body pattern. This should be called last.
-  bool matchBody() {
+  bool matchBody(bool zeroPointOffset = false) {
     if (!matched)
       return false;
     Block *body = op.getBlock();
     auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
     switch (poolingType) {
     case PoolingType::NONE:
-      return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body);
+      return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body,
+                                          zeroPointOffset);
     case PoolingType::MAX_SIGNED:
       return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body);
     case PoolingType::MAX_UNSIGNED:
@@ -620,6 +668,361 @@ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
       .matchBody();
 }
 
+// #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwcHwcfOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr F = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+  AffineExpr c = m.dim(6);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
+                   /*filterMap=*/{h, w, c, F},
+                   /*outputMap=*/{N, H, W, F}})
+      .matchBody();
+}
+
+// #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)>
+// #scalarMap = affine_map<(N, H, W, F, h, w, c) -> ()>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwcHwcfQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr F = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+  AffineExpr c = m.dim(6);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
+                   /*filterMap=*/{h, w, c, F},
+                   /*scalarMap=*/{},
+                   /*scalarMap=*/{},
+                   /*outputMap=*/{N, H, W, F}})
+      .matchBody(/*zeroPointOffset=*/true);
+}
+
+// #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwcFhwcOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr F = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+  AffineExpr c = m.dim(6);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
+                   /*filterMap=*/{F, h, w, c},
+                   /*outputMap=*/{N, H, W, F}})
+      .matchBody();
+}
+
+// #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)>
+// #scalarMap = affine_map<(N, H, W, F, h, w, c) -> ()>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwcFhwcQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr F = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+  AffineExpr c = m.dim(6);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
+                   /*filterMap=*/{F, h, w, c},
+                   /*scalarMap=*/{},
+                   /*scalarMap=*/{},
+                   /*outputMap=*/{N, H, W, F}})
+      .matchBody(/*zeroPointOffset=*/true);
+}
+
+// #inputMap  = affine_map<(N, F, H, W, c, h, w) -> (N, c, H + h, W + w)>
+// #filterMap = affine_map<(N, F, H, W, c, h, w) -> (F, c, h, w)>
+// #outputMap = affine_map<(N, F, H, W, c, h, w) -> (N, F, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNchwFchwOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr F = m.dim(1);
+  AffineExpr H = m.dim(2);
+  AffineExpr W = m.dim(3);
+  AffineExpr c = m.dim(4);
+  AffineExpr h = m.dim(5);
+  AffineExpr w = m.dim(6);
+
+  return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
+      .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, c, m.strided(H, h, 0), m.strided(W, w, 1)},
+                   /*filterMap=*/{F, c, h, w},
+                   /*outputMap=*/{N, F, H, W}})
+      .matchBody();
+}
+
+// #inputMap  = affine_map<(N, F, H, W, c, h, w) -> (N, c, H + h, W + w)>
+// #filterMap = affine_map<(N, F, H, W, c, h, w) -> (F, c, h, w)>
+// #scalarMap = affine_map<(N, F, H, W, c, h, w) -> ()>
+// #outputMap = affine_map<(N, F, H, W, c, h, w) -> (N, F, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNchwFchwQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr F = m.dim(1);
+  AffineExpr H = m.dim(2);
+  AffineExpr W = m.dim(3);
+  AffineExpr c = m.dim(4);
+  AffineExpr h = m.dim(5);
+  AffineExpr w = m.dim(6);
+
+  return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
+      .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, c, m.strided(H, h, 0), m.strided(W, w, 1)},
+                   /*filterMap=*/{F, c, h, w},
+                   /*scalarMap=*/{},
+                   /*scalarMap=*/{},
+                   /*outputMap=*/{N, F, H, W}})
+      .matchBody(/*zeroPointOffset=*/true);
+}
+
+// #inputMap  = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, c, H + h, W + w)>
+// #filterMap = affine_map<(N, G, F, H, W, c, h, w) -> (F, G, c, h, w)>
+// #outputMap = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, F, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNgchwFgchwOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr G = m.dim(1);
+  AffineExpr F = m.dim(2);
+  AffineExpr H = m.dim(3);
+  AffineExpr W = m.dim(4);
+  AffineExpr c = m.dim(5);
+  AffineExpr h = m.dim(6);
+  AffineExpr w = m.dim(7);
+
+  return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
+      .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
+      .expectMaps(
+          {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
+           /*filterMap=*/{F, G, c, h, w},
+           /*outputMap=*/{N, G, F, H, W}})
+      .matchBody();
+}
+
+// #inputMap  = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, c, H + h, W + w)>
+// #filterMap = affine_map<(N, G, F, H, W, c, h, w) -> (G, F, c, h, w)>
+// #outputMap = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, F, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNgchwGfchwOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr G = m.dim(1);
+  AffineExpr F = m.dim(2);
+  AffineExpr H = m.dim(3);
+  AffineExpr W = m.dim(4);
+  AffineExpr c = m.dim(5);
+  AffineExpr h = m.dim(6);
+  AffineExpr w = m.dim(7);
+
+  return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
+      .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
+      .expectMaps(
+          {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
+           /*filterMap=*/{G, F, c, h, w},
+           /*outputMap=*/{N, G, F, H, W}})
+      .matchBody();
+}
+
+// #inputMap  = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, c, H + h, W + w)>
+// #filterMap = affine_map<(N, G, F, H, W, c, h, w) -> (G, F, c, h, w)>
+// #scalarMap = affine_map<(N, G, F, H, W, c, h, w) -> ()>
+// #outputMap = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, F, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNgchwGfchwQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr G = m.dim(1);
+  AffineExpr F = m.dim(2);
+  AffineExpr H = m.dim(3);
+  AffineExpr W = m.dim(4);
+  AffineExpr c = m.dim(5);
+  AffineExpr h = m.dim(6);
+  AffineExpr w = m.dim(7);
+
+  return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
+      .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
+      .expectMaps(
+          {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
+           /*filterMap=*/{G, F, c, h, w},
+           /*scalarMap=*/{},
+           /*scalarMap=*/{},
+           /*outputMap=*/{N, G, F, H, W}})
+      .matchBody(/*zeroPointOffset=*/true);
+}
+
+// #inputMap  = affine_map<(N, H, W, G, F, h, w, c) -> (N, H + h, W + w, G, c)>
+// #filterMap = affine_map<(N, H, W, G, F, h, w, c) -> (G, F, h, w, c)>
+// #outputMap = affine_map<(N, H, W, G, F, h, w, c) -> (N, H, W, G, F)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwgcGfhwcOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr G = m.dim(3);
+  AffineExpr F = m.dim(4);
+  AffineExpr h = m.dim(5);
+  AffineExpr w = m.dim(6);
+  AffineExpr c = m.dim(7);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1)
+      .expectMaps(
+          {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
+           /*filterMap=*/{G, F, h, w, c},
+           /*outputMap=*/{N, H, W, G, F}})
+      .matchBody();
+}
+
+// #inputMap  = affine_map<(N, H, W, G, F, h, w, c) -> (N, H + h, W + w, G, c)>
+// #filterMap = affine_map<(N, H, W, G, F, h, w, c) -> (G, F, h, w, c)>
+// #scalarMap = affine_map<(N, H, W, G, F, h, w, c) -> ()>
+// #outputMap = affine_map<(N, H, W, G, F, h, w, c) -> (N, H, W, G, F)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr G = m.dim(3);
+  AffineExpr F = m.dim(4);
+  AffineExpr h = m.dim(5);
+  AffineExpr w = m.dim(6);
+  AffineExpr c = m.dim(7);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1)
+      .expectMaps(
+          {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
+           /*filterMap=*/{G, F, h, w, c},
+           /*scalarMap=*/{},
+           /*scalarMap=*/{},
+           /*outputMap=*/{N, H, W, G, F}})
+      .matchBody(/*zeroPointOffset=*/true);
+}
+
 // #inputMap  = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)>
 // #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)>
 // #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)>
@@ -759,6 +1162,130 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
       .matchBody();
 }
 
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w, C)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr C = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                   /*filterMap=*/{h, w, C},
+                   /*outputMap=*/{N, H, W, C}})
+      .matchBody();
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w, C)>
+// #scalarMap = affine_map<(N, H, W, C, h, w) -> ()>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr C = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                   /*filterMap=*/{h, w, C},
+                   /*scalarMap=*/{},
+                   /*scalarMap=*/{},
+                   /*outputMap=*/{N, H, W, C}})
+      .matchBody(/*zeroPointOffset=*/true);
+}
+
+// #inputMap = affine_map<(N, H, W, C, CM, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, CM, h, w) -> (h, w, C, CM)>
+// #outputMap = affine_map<(N, H, W, C, CM, h, w) -> (N, H, W, C, CM)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr C = m.dim(3);
+  AffineExpr CM = m.dim(4);
+  AffineExpr h = m.dim(5);
+  AffineExpr w = m.dim(6);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                   /*filterMap=*/{h, w, C, CM},
+                   /*outputMap=*/{N, H, W, C, CM}})
+      .matchBody();
+}
+
+// #inputMap = affine_map<(N, H, W, C, CM, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, CM, h, w) -> (h, w, C, CM)>
+// #scalarMap = affine_map<(N, H, W, C, CM, h, w) -> ()>
+// #outputMap = affine_map<(N, H, W, C, CM, h, w) -> (N, H, W, C, CM)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcmQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr C = m.dim(3);
+  AffineExpr CM = m.dim(4);
+  AffineExpr h = m.dim(5);
+  AffineExpr w = m.dim(6);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                   /*filterMap=*/{h, w, C, CM},
+                   /*scalarMap=*/{},
+                   /*scalarMap=*/{},
+                   /*outputMap=*/{N, H, W, C, CM}})
+      .matchBody(/*zeroPointOffset=*/true);
+}
+
 // #inputMap = affine_map<(N, D, H, W, CM, d, h, w, C)
 //                    -> (N, D + d, H + h, W + w, C)>
 // #filterMap = affine_map<(N, D, H, W, CM, d, h, w, C)
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
index 4b2d42a3ae4e0..289d55ce9911a 100644
--- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -55,6 +55,149 @@ func.func @conv_2d(%in : tensor<?x?xf32>, %filter : tensor<?x?xf32>, %out : tens
 
 // -----
 
+func.func @conv_2d_nhwc_hwcf(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf // op under test; the expectation after this function requires the same named op back
+         {dilations = dense<2> : tensor<2xi64>, strides = dense<3> : tensor<2xi64>} // splat attrs: both dilations are 2, both strides are 3
+         ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+         outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nhwc_hwcf
+//      CHECK:   linalg.conv_2d_nhwc_hwcf
+// CHECK-SAME:      dilations = dense<2> : tensor<2xi64>, strides = dense<3> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_nhwc_hwcf_q(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?x?x?xi8>, %output: tensor<?x?x?x?xi32>, %zp_input: i32, %zp_filter: i32) -> tensor<?x?x?x?xi32> {
+  %0 = linalg.conv_2d_nhwc_hwcf_q // op under test; the expectation after this function requires the same named op back
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} // splat attrs: all dilations and strides are 1
+         ins (%input, %filter, %zp_input, %zp_filter : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32) // two i32 scalars follow the i8 tensors (presumably zero points, going by the names)
+         outs (%output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> // i8 inputs, i32 result
+  return %0 : tensor<?x?x?x?xi32>
+}
+//      CHECK: @conv_2d_nhwc_hwcf_q
+//      CHECK:   linalg.conv_2d_nhwc_hwcf_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_nhwc_fhwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc // op under test; the expectation after this function requires the same named op back
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} // splat attrs: both dilations are 1, both strides are 2
+         ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+         outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nhwc_fhwc
+//      CHECK:   linalg.conv_2d_nhwc_fhwc
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_nhwc_fhwc_q(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?x?x?xi8>, %output: tensor<?x?x?x?xi32>, %zp_input: i32, %zp_filter: i32) -> tensor<?x?x?x?xi32> {
+  %0 = linalg.conv_2d_nhwc_fhwc_q // op under test; the expectation after this function requires the same named op back
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} // splat attrs: all dilations and strides are 1
+         ins (%input, %filter, %zp_input, %zp_filter : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32) // two i32 scalars follow the i8 tensors (presumably zero points, going by the names)
+         outs (%output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> // i8 inputs, i32 result
+  return %0 : tensor<?x?x?x?xi32>
+}
+//      CHECK: @conv_2d_nhwc_fhwc_q
+//      CHECK:   linalg.conv_2d_nhwc_fhwc_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_nchw_fchw(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nchw_fchw // op under test; the expectation after this function requires the same named op back
+         {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[3, 4]> : tensor<2xi64>} // per-dimension values, all four distinct (presumably to expose a dilation/stride or dimension mix-up)
+         ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+         outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nchw_fchw
+//      CHECK:   linalg.conv_2d_nchw_fchw
+// CHECK-SAME:      dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[3, 4]> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_nchw_fchw_q(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?x?x?xi8>, %output: tensor<?x?x?x?xi32>, %zp_input: i32, %zp_filter: i32) -> tensor<?x?x?x?xi32> {
+  %0 = linalg.conv_2d_nchw_fchw_q // op under test; the expectation after this function requires the same named op back
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} // splat attrs: all dilations and strides are 1
+         ins (%input, %filter, %zp_input, %zp_filter : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32) // two i32 scalars follow the i8 tensors (presumably zero points, going by the names)
+         outs (%output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> // i8 inputs, i32 result
+  return %0 : tensor<?x?x?x?xi32>
+}
+//      CHECK: @conv_2d_nchw_fchw_q
+//      CHECK:   linalg.conv_2d_nchw_fchw_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_ngchw_fgchw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.conv_2d_ngchw_fgchw // op under test; the expectation after this function requires the same named op back
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} // splat attrs: all dilations and strides are 1
+         ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) // input and filter are both rank-5 here
+         outs (%output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_ngchw_fgchw
+//      CHECK:   linalg.conv_2d_ngchw_fgchw
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_ngchw_gfchw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.conv_2d_ngchw_gfchw // op under test; the expectation after this function requires the same named op back
+         {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} // splat attrs: both dilations are 2, both strides are 1
+         ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) // input and filter are both rank-5 here
+         outs (%output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_ngchw_gfchw
+//      CHECK:   linalg.conv_2d_ngchw_gfchw
+// CHECK-SAME:      dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_ngchw_gfchw_q(%input: tensor<?x?x?x?x?xi8>, %filter: tensor<?x?x?x?x?xi8>, %output: tensor<?x?x?x?x?xi32>, %zp_input: i32, %zp_filter: i32) -> tensor<?x?x?x?x?xi32> {
+  %0 = linalg.conv_2d_ngchw_gfchw_q // op under test; the expectation after this function requires the same named op back
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} // splat attrs: all dilations and strides are 1
+         ins (%input, %filter, %zp_input, %zp_filter : tensor<?x?x?x?x?xi8>, tensor<?x?x?x?x?xi8>, i32, i32) // two i32 scalars follow the rank-5 i8 tensors (presumably zero points, going by the names)
+         outs (%output: tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32> // i8 inputs, i32 result
+  return %0 : tensor<?x?x?x?x?xi32>
+}
+//      CHECK: @conv_2d_ngchw_gfchw_q
+//      CHECK:   linalg.conv_2d_ngchw_gfchw_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_nhwgc_gfhwc(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwgc_gfhwc // op under test; the expectation after this function requires the same named op back
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} // splat attrs: both dilations are 1, both strides are 2
+         ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) // input and filter are both rank-5 here
+         outs (%output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nhwgc_gfhwc
+//      CHECK:   linalg.conv_2d_nhwgc_gfhwc
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_nhwgc_gfhwc_q(%input: tensor<?x?x?x?x?xi8>, %filter: tensor<?x?x?x?x?xi8>, %output: tensor<?x?x?x?x?xi32>, %zp_input: i32, %zp_filter: i32) -> tensor<?x?x?x?x?xi32> {
+  %0 = linalg.conv_2d_nhwgc_gfhwc_q // op under test; the expectation after this function requires the same named op back
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} // splat attrs: all dilations and strides are 1
+         ins (%input, %filter, %zp_input, %zp_filter : tensor<?x?x?x?x?xi8>, tensor<?x?x?x?x?xi8>, i32, i32) // two i32 scalars follow the rank-5 i8 tensors (presumably zero points, going by the names)
+         outs (%output: tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32> // i8 inputs, i32 result
+  return %0 : tensor<?x?x?x?x?xi32>
+}
+//      CHECK: @conv_2d_nhwgc_gfhwc_q
+//      CHECK:   linalg.conv_2d_nhwgc_gfhwc_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
 func.func @conv_3d(%in : tensor<?x?x?xf32>, %filter : tensor<?x?x?xf32>, %out : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
   %0 = linalg.conv_3d
          ins(%in, %filter : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
@@ -121,6 +264,58 @@ func.func @depthwise_conv_2d_nchw_chw(%input: tensor<?x?x?x?xf16>, %filter: tens
 
 // -----
 
+func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_2d_nhwc_hwc // op under test; the expectation after this function requires the same named op back
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} // splat attrs: both dilations are 1, both strides are 2
+         ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) // filter is rank 3 while input and output are rank 4
+         outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_2d_nhwc_hwc
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwc
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwc_q(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?x?xi8>, %output: tensor<?x?x?x?xi32>, %zp_input: i32, %zp_filter: i32) -> tensor<?x?x?x?xi32> {
+  %0 = linalg.depthwise_conv_2d_nhwc_hwc_q // op under test; the expectation after this function requires the same named op back
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} // splat attrs: all dilations and strides are 1
+         ins (%input, %filter, %zp_input, %zp_filter : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) // rank-4 input, rank-3 filter, then two i32 scalars (presumably zero points, going by the names)
+         outs (%output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> // i8 inputs, i32 result
+  return %0 : tensor<?x?x?x?xi32>
+}
+//      CHECK: @depthwise_conv_2d_nhwc_hwc_q
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwc_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwcm(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm // op under test; the expectation after this function requires the same named op back
+         {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[3, 1]> : tensor<2xi64>} // per-dimension values: dilations differ across the two entries, and so do strides
+         ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) // input and filter are rank 4
+         outs (%output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> // output is rank 5, one rank higher than the input
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_2d_nhwc_hwcm
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwcm
+// CHECK-SAME:      dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[3, 1]> : tensor<2xi64>
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwcm_q(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?x?x?xi8>, %output: tensor<?x?x?x?x?xi32>, %zp_input: i32, %zp_filter: i32) -> tensor<?x?x?x?x?xi32> {
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q // op under test; the expectation after this function requires the same named op back
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} // splat attrs: all dilations and strides are 1
+         ins (%input, %filter, %zp_input, %zp_filter : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32) // rank-4 i8 tensors, then two i32 scalars (presumably zero points, going by the names)
+         outs (%output: tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32> // rank-5 i32 output
+  return %0 : tensor<?x?x?x?x?xi32>
+}
+//      CHECK: @depthwise_conv_2d_nhwc_hwcm_q
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwcm_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
 func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %output: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> {
   %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm
          {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}



More information about the Mlir-commits mailing list