[Mlir-commits] [mlir] [Linalg] Add *Conv2D* matchers (PR #168362)

Abhishek Varma llvmlistbot at llvm.org
Wed Nov 26 05:26:54 PST 2025


https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/168362

>From 7b4d4b714b43684a87f5a35dea377402f8262801 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Mon, 17 Nov 2025 06:44:55 -0600
Subject: [PATCH 1/3] [Linalg] Add *Conv2D* matchers

-- This commit is the second in the series of adding matchers
   for linalg.*conv*/*pool*. Refer: https://github.com/llvm/llvm-project/pull/163724
-- In this commit all variants of Conv2D convolution ops have been
   added.

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
 .../Dialect/Linalg/Transforms/Specialize.cpp  |  15 +
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 973 ++++++++++++++++--
 .../convolution/roundtrip-convolution.mlir    | 155 ++-
 3 files changed, 1046 insertions(+), 97 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index c2485a08932dd..b52b93f8cc9b9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -279,6 +279,17 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
   CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
   CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
   CONV_OP_SPECIALIZER(linalg::Conv2DOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwFgchwOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcOp);
   CONV_OP_SPECIALIZER(linalg::Conv3DOp);
   // -----------------------------
   // Depthwise Convolution ops.
@@ -287,6 +298,10 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcQOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmQOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
   // -----------------------------
   // Pooling ops.
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 6b85e6ba0ede2..57593abac7ab0 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -240,8 +240,8 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
 //===----------------------------------------------------------------------===//
 
 /// Returns the BlockArgument that leads to `val`, if any. Traverses optional
-/// ext* ops.
-static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
+/// ext*/sitofp ops.
+static BlockArgument getBlockArgumentWithOptionalCastOps(Value val) {
   BlockArgument blockArg = dyn_cast<BlockArgument>(val);
   if ((blockArg))
     return blockArg;
@@ -249,18 +249,62 @@ static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
   Operation *defOp = val.getDefiningOp();
   if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
       !dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
-      !dyn_cast_if_present<arith::ExtUIOp>(defOp)) {
+      !dyn_cast_if_present<arith::ExtUIOp>(defOp) &&
+      !dyn_cast_if_present<arith::SIToFPOp>(defOp)) {
     return nullptr;
   }
   return dyn_cast<BlockArgument>(defOp->getOperand(0));
 }
 
+/// Matches the body of zero-point-offset (`*Q`) convolution ops, given the
+/// accumulate (`addOp`) and multiply (`mulOp`) ops found by the caller. Checks
+/// that the mul operands are `%input - %input_zp` (block args 0 and 2) and
+/// `%filter - %filter_zp` (block args 1 and 3), using arith.subi/subf, and
+/// that addOp's first operand is block arg 4. Every block argument may be
+/// wrapped in an optional ext*/sitofp cast.
+static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp,
+                                           Block *body) {
+  Operation *subOp1 = mulOp->getOperand(0).getDefiningOp();
+  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(subOp1))
+    return false;
+  Operation *subOp2 = mulOp->getOperand(1).getDefiningOp();
+  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(subOp2))
+    return false;
+  BlockArgument inputBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp1->getOperand(0));
+  BlockArgument inputScalarBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp1->getOperand(1));
+  BlockArgument filterBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp2->getOperand(0));
+  BlockArgument filterScalarBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp2->getOperand(1));
+  BlockArgument outBlockArg =
+      getBlockArgumentWithOptionalCastOps(addOp->getOperand(0));
+  if (!inputBlockArg || !inputScalarBlockArg || !filterBlockArg ||
+      !filterScalarBlockArg || !outBlockArg ||
+      inputBlockArg.getOwner() != body ||
+      inputScalarBlockArg.getOwner() != body ||
+      filterBlockArg.getOwner() != body ||
+      filterScalarBlockArg.getOwner() != body ||
+      outBlockArg.getOwner() != body || inputBlockArg.getArgNumber() != 0 ||
+      inputScalarBlockArg.getArgNumber() != 2 ||
+      filterBlockArg.getArgNumber() != 1 ||
+      filterScalarBlockArg.getArgNumber() != 3 ||
+      outBlockArg.getArgNumber() != 4)
+    return false;
+  return true;
+}
+
 /// Utility to match block body for convolution ops.
 /// The body is thus expected to yield :-
 ///     %out + (%lhs * %rhs)
 ///   where: %lhs, %rhs and %out are block arguments and
 ///          %lhs and %rhs can have optional upcast operation.
-static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) {
+/// NOTE: For zero-point-offset ops (`zeroPointOffset` = true), %lhs and %rhs
+///       are instead `%input - %input_zp` and `%filter - %filter_zp`, where
+///       each of these operands can have an optional upcast operation.
+static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
+                                         bool zeroPointOffset = false) {
   Operation *addOp = yieldVal.getDefiningOp();
   if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp))
     return false;
@@ -269,12 +313,15 @@ static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) {
   if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
     return false;
 
+  if (zeroPointOffset) {
+    return bodyMatcherForZeroPointOffsets(addOp, mulOp, body);
+  }
   BlockArgument lhsBlockArg =
-      getBlockArgumentWithOptionalExtOps(mulOp->getOperand(0));
+      getBlockArgumentWithOptionalCastOps(mulOp->getOperand(0));
   BlockArgument rhsBlockArg =
-      getBlockArgumentWithOptionalExtOps(mulOp->getOperand(1));
+      getBlockArgumentWithOptionalCastOps(mulOp->getOperand(1));
   BlockArgument outBlockArg =
-      getBlockArgumentWithOptionalExtOps(addOp->getOperand(0));
+      getBlockArgumentWithOptionalCastOps(addOp->getOperand(0));
   if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
       lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
       outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
@@ -291,9 +338,9 @@ static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
     return false;
 
   BlockArgument lhsArg =
-      getBlockArgumentWithOptionalExtOps(defOp->getOperand(0));
+      getBlockArgumentWithOptionalCastOps(defOp->getOperand(0));
   BlockArgument rhsArg =
-      getBlockArgumentWithOptionalExtOps(defOp->getOperand(1));
+      getBlockArgumentWithOptionalCastOps(defOp->getOperand(1));
   if (!lhsArg || !rhsArg || lhsArg.getOwner() != body ||
       rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 ||
       rhsArg.getArgNumber() != 0)
@@ -599,49 +646,45 @@ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
   return bodyMatcherForConvolutionOps(yieldVal, body);
 }
 
-// #inputMap  = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)>
-// #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)>
-// #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)>
+// #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
 template <>
-bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
-                                              SmallVector<int64_t> *dilations,
-                                              SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv3DOp>(op))
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwcFhwcOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(3, 1);
-  *strides = SmallVector<int64_t>(3, 1);
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
   MLIRContext *context = op->getContext();
-  AffineExpr D = getAffineDimExpr(0, context);
+  AffineExpr N = getAffineDimExpr(0, context);
   AffineExpr H = getAffineDimExpr(1, context);
   AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr d = getAffineDimExpr(3, context);
+  AffineExpr F = getAffineDimExpr(3, context);
   AffineExpr h = getAffineDimExpr(4, context);
   AffineExpr w = getAffineDimExpr(5, context);
+  AffineExpr c = getAffineDimExpr(6, context);
   ArrayAttr indexingMaps = op.getIndexingMaps();
   // First fetch dilations/strides :-
-  // Match: D * stride + d * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
-                                  /*oDim=*/0, (*dilations)[0], (*strides)[0]))
-    return false;
   // Match: H * stride + h * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
-                                  /*oDim=*/1, (*dilations)[1], (*strides)[1]))
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
   // Match: W * stride + w * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
-                                  /*oDim=*/2, (*dilations)[2], (*strides)[2]))
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
   // Match expected indexing maps
   if (!convLayoutMatches(
-          {/*inputMap=*/{D * (*strides)[0] + d * (*dilations)[0],
-                         H * (*strides)[1] + h * (*dilations)[1],
-                         W * (*strides)[2] + w * (*dilations)[2]},
-           /*filterMap=*/{d, h, w},
-           /*outputMap=*/{D, H, W}},
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{F, h, w, c},
+           /*outputMap=*/{N, H, W, F}},
           indexingMaps, context))
     return false;
   // Match body
@@ -651,37 +694,45 @@ bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
   return bodyMatcherForConvolutionOps(yieldVal, body);
 }
 
-// #inputMap  = affine_map<(N, W, C, w) -> (N, C, W + w)>
-// #filterMap = affine_map<(N, W, C, w) -> (C, w)>
-// #outputMap = affine_map<(N, W, C, w) -> (N, C, W)>
+// #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
 template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
+  if (isa<linalg::Conv2DNhwcHwcfOp>(op))
     return true;

   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");

-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
   MLIRContext *context = op->getContext();
   AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr W = getAffineDimExpr(1, context);
-  AffineExpr C = getAffineDimExpr(2, context);
-  AffineExpr w = getAffineDimExpr(3, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr F = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  AffineExpr c = getAffineDimExpr(6, context);
   ArrayAttr indexingMaps = op.getIndexingMaps();
   // First fetch dilations/strides :-
+  // Match: H * stride + h * dilation (HWCF filter: h is dim 0, w is dim 1)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
   // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
   // Match expected indexing maps
   if (!convLayoutMatches(
-          {/*inputMap=*/{N, C, W * (*strides)[0] + w * (*dilations)[0]},
-           /*filterMap=*/{C, w},
-           /*outputMap=*/{N, C, W}},
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{h, w, c, F},
+           /*outputMap=*/{N, H, W, F}},
           indexingMaps, context))
     return false;
   // Match body
@@ -691,37 +742,196 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
   return bodyMatcherForConvolutionOps(yieldVal, body);
 }
 
-// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
-// #filterMap = affine_map<(N, W, C, w) -> (w, C)>
-// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
+// #inputMap  = affine_map<(N, F, H, W, C, h, w) -> (N, C, H + h, W + w)>
+// #filterMap = affine_map<(N, F, H, W, C, h, w) -> (F, C, h, w)>
+// #outputMap = affine_map<(N, F, H, W, C, h, w) -> (N, F, H, W)>
 template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+  if (isa<linalg::Conv2DNchwFchwOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
   MLIRContext *context = op->getContext();
   AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr W = getAffineDimExpr(1, context);
-  AffineExpr C = getAffineDimExpr(2, context);
-  AffineExpr w = getAffineDimExpr(3, context);
+  AffineExpr F = getAffineDimExpr(1, context);
+  AffineExpr H = getAffineDimExpr(2, context);
+  AffineExpr W = getAffineDimExpr(3, context);
+  AffineExpr C = getAffineDimExpr(4, context);
+  AffineExpr h = getAffineDimExpr(5, context);
+  AffineExpr w = getAffineDimExpr(6, context);
   ArrayAttr indexingMaps = op.getIndexingMaps();
   // First fetch dilations/strides :-
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+    return false;
   // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                  /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1]},
+           /*filterMap=*/{F, C, h, w},
+           /*outputMap=*/{N, F, H, W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)>
+// #scalarMap = affine_map<(N, H, W, F, h, w, c) -> ()>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwcFhwcQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr F = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  AffineExpr c = getAffineDimExpr(6, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides :-
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
                                   /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
   // Match expected indexing maps
   if (!convLayoutMatches(
-          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
-           /*filterMap=*/{w, C},
-           /*outputMap=*/{N, W, C}},
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{F, h, w, c},
+           /*scalarMap=*/{},
+           /*scalarMap=*/{},
+           /*outputMap=*/{N, H, W, F}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+}
+
+// #inputMap  = affine_map<(N, F, H, W, C, h, w) -> (N, C, H + h, W + w)>
+// #filterMap = affine_map<(N, F, H, W, C, h, w) -> (F, C, h, w)>
+// #scalarMap = affine_map<(N, F, H, W, C, h, w) -> ()>
+// #outputMap = affine_map<(N, F, H, W, C, h, w) -> (N, F, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNchwFchwQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr F = getAffineDimExpr(1, context);
+  AffineExpr H = getAffineDimExpr(2, context);
+  AffineExpr W = getAffineDimExpr(3, context);
+  AffineExpr C = getAffineDimExpr(4, context);
+  AffineExpr h = getAffineDimExpr(5, context);
+  AffineExpr w = getAffineDimExpr(6, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides :-
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                  /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1]},
+           /*filterMap=*/{F, C, h, w},
+           /*scalarMap=*/{},
+           /*scalarMap=*/{},
+           /*outputMap=*/{N, F, H, W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+}
+
+// #inputMap  = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, C, H + h, W + w)>
+// #filterMap = affine_map<(N, G, FG, H, W, C, h, w) -> (FG, G, C, h, w)>
+// #outputMap = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, FG, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNgchwFgchwOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr G = getAffineDimExpr(1, context);
+  AffineExpr FG = getAffineDimExpr(2, context);
+  AffineExpr H = getAffineDimExpr(3, context);
+  AffineExpr W = getAffineDimExpr(4, context);
+  AffineExpr C = getAffineDimExpr(5, context);
+  AffineExpr h = getAffineDimExpr(6, context);
+  AffineExpr w = getAffineDimExpr(7, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides :-
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                  /*oDim=*/3, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
+                                  /*oDim=*/4, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, G, C, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1]},
+           /*filterMap=*/{FG, G, C, h, w},
+           /*outputMap=*/{N, G, FG, H, W}},
           indexingMaps, context))
     return false;
   // Match body
@@ -731,38 +941,46 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
   return bodyMatcherForConvolutionOps(yieldVal, body);
 }
 
-// #inputMap  = affine_map<(N, W, C, CM, w) -> (N, W + w, C)>
-// #filterMap = affine_map<(N, W, C, CM, w) -> (w, C, CM)>
-// #outputMap = affine_map<(N, W, C, CM, w) -> (N, W, C, CM)>
+// #inputMap  = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, C, H + h, W + w)>
+// #filterMap = affine_map<(N, G, FG, H, W, C, h, w) -> (G, FG, C, h, w)>
+// #outputMap = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, FG, H, W)>
 template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
+  if (isa<linalg::Conv2DNgchwGfchwOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
   MLIRContext *context = op->getContext();
   AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr W = getAffineDimExpr(1, context);
-  AffineExpr C = getAffineDimExpr(2, context);
-  AffineExpr CM = getAffineDimExpr(3, context);
-  AffineExpr w = getAffineDimExpr(4, context);
+  AffineExpr G = getAffineDimExpr(1, context);
+  AffineExpr FG = getAffineDimExpr(2, context);
+  AffineExpr H = getAffineDimExpr(3, context);
+  AffineExpr W = getAffineDimExpr(4, context);
+  AffineExpr C = getAffineDimExpr(5, context);
+  AffineExpr h = getAffineDimExpr(6, context);
+  AffineExpr w = getAffineDimExpr(7, context);
   ArrayAttr indexingMaps = op.getIndexingMaps();
   // First fetch dilations/strides :-
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                  /*oDim=*/3, (*dilations)[0], (*strides)[0]))
+    return false;
   // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
+                                  /*oDim=*/4, (*dilations)[1], (*strides)[1]))
     return false;
   // Match expected indexing maps
   if (!convLayoutMatches(
-          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
-           /*filterMap=*/{w, C, CM},
-           /*outputMap=*/{N, W, C, CM}},
+          {/*inputMap=*/{N, G, C, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1]},
+           /*filterMap=*/{G, FG, C, h, w},
+           /*outputMap=*/{N, G, FG, H, W}},
           indexingMaps, context))
     return false;
   // Match body
@@ -772,14 +990,15 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
   return bodyMatcherForConvolutionOps(yieldVal, body);
 }
 
-// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
-// #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)>
-// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)>
+// #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)>
+// #scalarMap = affine_map<(N, H, W, F, h, w, c) -> ()>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
 template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
+  if (isa<linalg::Conv2DNhwcHwcfQOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
@@ -791,32 +1010,604 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
   AffineExpr N = getAffineDimExpr(0, context);
   AffineExpr H = getAffineDimExpr(1, context);
   AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr C = getAffineDimExpr(3, context);
+  AffineExpr F = getAffineDimExpr(3, context);
   AffineExpr h = getAffineDimExpr(4, context);
   AffineExpr w = getAffineDimExpr(5, context);
+  AffineExpr c = getAffineDimExpr(6, context);
   ArrayAttr indexingMaps = op.getIndexingMaps();
   // First fetch dilations/strides :-
   // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
   // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
-                                  /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
   // Match expected indexing maps
   if (!convLayoutMatches(
-          {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1]},
-           /*filterMap=*/{C, h, w},
-           /*outputMap=*/{N, C, H, W}},
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{h, w, c, F},
+           /*scalarMap=*/{},
+           /*scalarMap=*/{},
+           /*outputMap=*/{N, H, W, F}},
           indexingMaps, context))
     return false;
   // Match body
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+}
+
+// #inputMap  = affine_map<(N, H, W, G, FG, h, w, c) -> (N, H + h, W + w, G, c)>
+// #filterMap = affine_map<(N, H, W, G, FG, h, w, c) -> (G, FG, h, w, c)>
+// #scalarMap = affine_map<(N, H, W, G, FG, h, w, c) -> ()> (x2: zero points)
+// #outputMap = affine_map<(N, H, W, G, FG, h, w, c) -> (N, H, W, G, FG)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr G = getAffineDimExpr(3, context);
+  AffineExpr FG = getAffineDimExpr(4, context);
+  AffineExpr h = getAffineDimExpr(5, context);
+  AffineExpr w = getAffineDimExpr(6, context);
+  AffineExpr c = getAffineDimExpr(7, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides from the H/W input exprs :-
+  // Match: H * stride + h * dilation (input/filter/output dims 1/2/1)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation (input/filter/output dims 2/3/2)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps (input, filter, 2 zero points, output)
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], G, c},
+           /*filterMap=*/{G, FG, h, w, c},
+           /*inputZpMap=*/{},
+           /*filterZpMap=*/{},
+           /*outputMap=*/{N, H, W, G, FG}},
+          indexingMaps, context))
+    return false;
+  // Match body, with zero-point offsets enabled in the body matcher
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+}
+
+// #inputMap  = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, C, H + h, W + w)>
+// #filterMap = affine_map<(N, G, FG, H, W, C, h, w) -> (G, FG, C, h, w)>
+// #scalarMap = affine_map<(N, G, FG, H, W, C, h, w) -> ()> (x2: zero points)
+// #outputMap = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, FG, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNgchwGfchwQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr G = getAffineDimExpr(1, context);
+  AffineExpr FG = getAffineDimExpr(2, context);
+  AffineExpr H = getAffineDimExpr(3, context);
+  AffineExpr W = getAffineDimExpr(4, context);
+  AffineExpr C = getAffineDimExpr(5, context);
+  AffineExpr h = getAffineDimExpr(6, context);
+  AffineExpr w = getAffineDimExpr(7, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides from the H/W input exprs :-
+  // Match: H * stride + h * dilation (input/filter/output dims 3/3/3)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                  /*oDim=*/3, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation (input/filter/output dims 4/4/4)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
+                                  /*oDim=*/4, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps (input, filter, 2 zero points, output)
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, G, C, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1]},
+           /*filterMap=*/{G, FG, C, h, w},
+           /*inputZpMap=*/{},
+           /*filterZpMap=*/{},
+           /*outputMap=*/{N, G, FG, H, W}},
+          indexingMaps, context))
+    return false;
+  // Match body, with zero-point offsets enabled in the body matcher
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+}
+
+// #inputMap  = affine_map<(N, H, W, G, FG, h, w, C) -> (N, H + h, W + w, G, C)>
+// #filterMap = affine_map<(N, H, W, G, FG, h, w, C) -> (G, FG, h, w, C)>
+// #outputMap = affine_map<(N, H, W, G, FG, h, w, C) -> (N, H, W, G, FG)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwgcGfhwcOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr G = getAffineDimExpr(3, context);
+  AffineExpr FG = getAffineDimExpr(4, context);
+  AffineExpr h = getAffineDimExpr(5, context);
+  AffineExpr w = getAffineDimExpr(6, context);
+  AffineExpr C = getAffineDimExpr(7, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides; h/w are filter dims 2/3 of GFHWC :-
+  // Match: H * stride + h * dilation (input/filter/output dims 1/2/1)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation (input/filter/output dims 2/3/2)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps (input, filter, output)
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], G, C},
+           /*filterMap=*/{G, FG, h, w, C},
+           /*outputMap=*/{N, H, W, G, FG}},
+          indexingMaps, context))
+    return false;
+  // Match body using the shared convolution body matcher
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap  = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)>
+// #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)>
+// #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
+                                              SmallVector<int64_t> *dilations,
+                                              SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv3DOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(3, 1);
+  *strides = SmallVector<int64_t>(3, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr D = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr d = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides from the D/H/W input exprs :-
+  // Match: D * stride + d * dilation (input/filter/output dims 0/0/0)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
+                                  /*oDim=*/0, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: H * stride + h * dilation (input/filter/output dims 1/1/1)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+                                  /*oDim=*/1, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match: W * stride + w * dilation (input/filter/output dims 2/2/2)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[2], (*strides)[2]))
+    return false;
+  // Match expected indexing maps (input, filter, output)
+  if (!convLayoutMatches(
+          {/*inputMap=*/{D * (*strides)[0] + d * (*dilations)[0],
+                         H * (*strides)[1] + h * (*dilations)[1],
+                         W * (*strides)[2] + w * (*dilations)[2]},
+           /*filterMap=*/{d, h, w},
+           /*outputMap=*/{D, H, W}},
+          indexingMaps, context))
+    return false;
+  // Match body using the shared convolution body matcher
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap  = affine_map<(N, W, C, w) -> (N, C, W + w)>
+// #filterMap = affine_map<(N, W, C, w) -> (C, w)>
+// #outputMap = affine_map<(N, W, C, w) -> (N, C, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(1, 1);
+  *strides = SmallVector<int64_t>(1, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr W = getAffineDimExpr(1, context);
+  AffineExpr C = getAffineDimExpr(2, context);
+  AffineExpr w = getAffineDimExpr(3, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilation/stride from the W input expr :-
+  // Match: W * stride + w * dilation (input/filter/output dims 2/1/2)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match expected indexing maps (input, filter, output)
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, C, W * (*strides)[0] + w * (*dilations)[0]},
+           /*filterMap=*/{C, w},
+           /*outputMap=*/{N, C, W}},
+          indexingMaps, context))
+    return false;
+  // Match body using the shared convolution body matcher
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap  = affine_map<(N, W, C, w) -> (N, W + w, C)>
+// #filterMap = affine_map<(N, W, C, w) -> (w, C)>
+// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(1, 1);
+  *strides = SmallVector<int64_t>(1, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr W = getAffineDimExpr(1, context);
+  AffineExpr C = getAffineDimExpr(2, context);
+  AffineExpr w = getAffineDimExpr(3, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilation/stride from the W input expr :-
+  // Match: W * stride + w * dilation (input/filter/output dims 1/0/1)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match expected indexing maps (input, filter, output)
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
+           /*filterMap=*/{w, C},
+           /*outputMap=*/{N, W, C}},
+          indexingMaps, context))
+    return false;
+  // Match body using the shared convolution body matcher
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap  = affine_map<(N, W, C, CM, w) -> (N, W + w, C)>
+// #filterMap = affine_map<(N, W, C, CM, w) -> (w, C, CM)>
+// #outputMap = affine_map<(N, W, C, CM, w) -> (N, W, C, CM)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(1, 1);
+  *strides = SmallVector<int64_t>(1, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr W = getAffineDimExpr(1, context);
+  AffineExpr C = getAffineDimExpr(2, context);
+  AffineExpr CM = getAffineDimExpr(3, context);
+  AffineExpr w = getAffineDimExpr(4, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilation/stride from the W input expr :-
+  // Match: W * stride + w * dilation (input/filter/output dims 1/0/1)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match expected indexing maps (input, filter, output)
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
+           /*filterMap=*/{w, C, CM},
+           /*outputMap=*/{N, W, C, CM}},
+          indexingMaps, context))
+    return false;
+  // Match body using the shared convolution body matcher
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap  = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr C = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides from the H/W input exprs :-
+  // Match: H * stride + h * dilation (input/filter/output dims 2/1/2)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation (input/filter/output dims 3/2/3)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps (input, filter, output)
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1]},
+           /*filterMap=*/{C, h, w},
+           /*outputMap=*/{N, C, H, W}},
+          indexingMaps, context))
+    return false;
+  // Match body using the shared convolution body matcher
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap  = affine_map<(N, H, W, c, h, w) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, c, h, w) -> (h, w, c)>
+// #outputMap = affine_map<(N, H, W, c, h, w) -> (N, H, W, c)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr c = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides from the H/W input exprs :-
+  // Match: H * stride + h * dilation (input/filter/output dims 1/0/1)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation (input/filter/output dims 2/1/2)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps (input, filter, output)
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{h, w, c},
+           /*outputMap=*/{N, H, W, c}},
+          indexingMaps, context))
+    return false;
+  // Match body using the shared convolution body matcher
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap  = affine_map<(N, H, W, c, cm, h, w) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, c, cm, h, w) -> (h, w, c, cm)>
+// #outputMap = affine_map<(N, H, W, c, cm, h, w) -> (N, H, W, c, cm)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr c = getAffineDimExpr(3, context);
+  AffineExpr cm = getAffineDimExpr(4, context);
+  AffineExpr h = getAffineDimExpr(5, context);
+  AffineExpr w = getAffineDimExpr(6, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides from the H/W input exprs :-
+  // Match: H * stride + h * dilation (input/filter/output dims 1/0/1)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation (input/filter/output dims 2/1/2)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps (input, filter, output)
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{h, w, c, cm},
+           /*outputMap=*/{N, H, W, c, cm}},
+          indexingMaps, context))
+    return false;
+  // Match body using the shared convolution body matcher
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap  = affine_map<(N, H, W, c, h, w) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, c, h, w) -> (h, w, c)>
+// #scalarMap = affine_map<(N, H, W, c, h, w) -> ()> (x2: zero points)
+// #outputMap = affine_map<(N, H, W, c, h, w) -> (N, H, W, c)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr c = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides from the H/W input exprs :-
+  // Match: H * stride + h * dilation (input/filter/output dims 1/0/1)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation (input/filter/output dims 2/1/2)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps (input, filter, 2 zero points, output)
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{h, w, c},
+           /*inputZpMap=*/{},
+           /*filterZpMap=*/{},
+           /*outputMap=*/{N, H, W, c}},
+          indexingMaps, context))
+    return false;
+  // Match body, with zero-point offsets enabled in the body matcher
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+}
+
+// #inputMap  = affine_map<(N, H, W, c, cm, h, w) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, c, cm, h, w) -> (h, w, c, cm)>
+// #scalarMap = affine_map<(N, H, W, c, cm, h, w) -> ()> (x2: zero points)
+// #outputMap = affine_map<(N, H, W, c, cm, h, w) -> (N, H, W, c, cm)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcmQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr c = getAffineDimExpr(3, context);
+  AffineExpr cm = getAffineDimExpr(4, context);
+  AffineExpr h = getAffineDimExpr(5, context);
+  AffineExpr w = getAffineDimExpr(6, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides from the H/W input exprs :-
+  // Match: H * stride + h * dilation (input/filter/output dims 1/0/1)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation (input/filter/output dims 2/1/2)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps (input, filter, 2 zero points, output)
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{h, w, c, cm},
+           /*inputZpMap=*/{},
+           /*filterZpMap=*/{},
+           /*outputMap=*/{N, H, W, c, cm}},
+          indexingMaps, context))
+    return false;
+  // Match body, with zero-point offsets enabled in the body matcher
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
 }
 
 // #inputMap = affine_map<(N, D, H, W, CM, d, h, w, C)
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
index 4b2d42a3ae4e0..11a2672d04632 100644
--- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -4,9 +4,9 @@
 
 // RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s --implicit-check-not=linalg.generic
 
-// -----------------------------
+//====================
 // Convolution ops.
-// -----------------------------
+//====================
 func.func @conv_1d(%in : tensor<?xf32>, %filter : tensor<?xf32>, %out : tensor<?xf32>) -> tensor<?xf32> {
   %0 = linalg.conv_1d
          ins(%in, %filter : tensor<?xf32>, tensor<?xf32>)
@@ -55,6 +55,97 @@ func.func @conv_2d(%in : tensor<?x?xf32>, %filter : tensor<?x?xf32>, %out : tens
 
 // -----
 
+func.func @conv_2d_nchw_fchw(%in : tensor<?x?x?x?xf32>, %filter : tensor<?x?x?x?xf32>, %out : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nchw_fchw
+         {dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>}
+         ins(%in, %filter : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+         outs(%out : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nchw_fchw
+//      CHECK:   linalg.conv_2d_nchw_fchw
+// CHECK-SAME:      dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_nchw_fchw_q(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?x?x?xi8>, %input_zp: i32, %filter_zp: i32, %output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+  %0 = linalg.conv_2d_nchw_fchw_q
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+         ins(%input, %filter, %input_zp, %filter_zp : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32)
+         outs(%output : tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?xi32>
+}
+//      CHECK: @conv_2d_nchw_fchw_q
+//      CHECK:   linalg.conv_2d_nchw_fchw_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_ngchw_fgchw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.conv_2d_ngchw_fgchw
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+         ins(%input, %filter : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+         outs(%output : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_ngchw_fgchw
+//      CHECK:   linalg.conv_2d_ngchw_fgchw
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_ngchw_gfchw(%input: tensor<?x?x?x?x?xi8>, %filter: tensor<?x?x?x?x?xi8>, %output: tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32> {
+  %0 = linalg.conv_2d_ngchw_gfchw
+         {dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>}
+         ins(%input, %filter : tensor<?x?x?x?x?xi8>, tensor<?x?x?x?x?xi8>)
+         outs(%output : tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?x?xi32>
+}
+//      CHECK: @conv_2d_ngchw_gfchw
+//      CHECK:   linalg.conv_2d_ngchw_gfchw
+// CHECK-SAME:      dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_ngchw_gfchw_q(%input: tensor<?x?x?x?x?xi8>, %filter: tensor<?x?x?x?x?xi8>, %inputzp: i32, %filterzp: i32, %output: tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32> {
+  %0 = linalg.conv_2d_ngchw_gfchw_q
+         {dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>}
+         ins(%input, %filter, %inputzp, %filterzp : tensor<?x?x?x?x?xi8>, tensor<?x?x?x?x?xi8>, i32, i32)
+         outs(%output : tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?x?xi32>
+}
+//      CHECK: @conv_2d_ngchw_gfchw_q
+//      CHECK:   linalg.conv_2d_ngchw_gfchw_q
+// CHECK-SAME:      dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_nhwc_hwcf_q(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %inputzp: i32, %filterzp: i32, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf_q
+         {dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>}
+         ins(%input, %filter, %inputzp, %filterzp : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, i32, i32)
+         outs(%output : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nhwc_hwcf_q
+//      CHECK:   linalg.conv_2d_nhwc_hwcf_q
+// CHECK-SAME:      dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_nhwgc_gfhwc_q(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %inputzp: i32, %filterzp: i32, %output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwgc_gfhwc_q
+         {dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>}
+         ins(%input, %filter, %inputzp, %filterzp : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>, i32, i32)
+         outs(%output : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nhwgc_gfhwc_q
+//      CHECK:   linalg.conv_2d_nhwgc_gfhwc_q
+// CHECK-SAME:      dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+
+// -----
+
 func.func @conv_3d(%in : tensor<?x?x?xf32>, %filter : tensor<?x?x?xf32>, %out : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
   %0 = linalg.conv_3d
          ins(%in, %filter : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
@@ -66,9 +157,9 @@ func.func @conv_3d(%in : tensor<?x?x?xf32>, %filter : tensor<?x?x?xf32>, %out :
 
 // -----
 
-// -----------------------------
+//====================
 // Depthwise Convolution ops.
-// -----------------------------
+//====================
 func.func @depthwise_conv_1d_ncw_cw(%input: tensor<?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
   %0 = linalg.depthwise_conv_1d_ncw_cw
          {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
@@ -121,6 +212,58 @@ func.func @depthwise_conv_2d_nchw_chw(%input: tensor<?x?x?x?xf16>, %filter: tens
 
 // -----
 
+func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_2d_nhwc_hwc
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+         ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?xf32>)
+         outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_2d_nhwc_hwc
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwc
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwc_q(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?x?xi8>, %inputzp: i32, %filterzp: i32, %output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>{
+  %res = linalg.depthwise_conv_2d_nhwc_hwc_q
+           {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+           ins(%input, %filter, %inputzp, %filterzp : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32)
+           outs(%output : tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+  return %res : tensor<?x?x?x?xi32>
+}
+//      CHECK: @depthwise_conv_2d_nhwc_hwc_q
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwc_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwcm(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+         ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+         outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_2d_nhwc_hwcm
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwcm
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwcm_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x?xi8>, %arg2: tensor<?x?x?x?x?xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x?xi32> {
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+         ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32)
+         outs(%arg2 : tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?x?xi32>
+}
+//      CHECK: @depthwise_conv_2d_nhwc_hwcm_q
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwcm_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
 func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %output: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> {
   %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm
          {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
@@ -134,9 +277,9 @@ func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<?x?x?x?x?xf32>, %filter:
 
 // -----
 
-// -----------------------------
+//====================
 // Pooling ops.
-// -----------------------------
+//====================
 func.func @pooling_nhwc_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
   %0 = linalg.pooling_nhwc_max
          {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}

>From ba9e125e29b11b742c43847e310a2af9e25e91d6 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 25 Nov 2025 02:32:39 -0600
Subject: [PATCH 2/3] Address review comments, add missing tests, reorder matchers

---
 .../Dialect/Linalg/Transforms/Specialize.cpp  |  14 +-
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 266 +++++++++---------
 .../convolution/roundtrip-convolution.mlir    |  39 +++
 3 files changed, 179 insertions(+), 140 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index b52b93f8cc9b9..ae71e822afee7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -271,7 +271,7 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
 #define CONV_OP_SPECIALIZER(ConvOpTy)                                          \
   if (isaConvolutionOpOfType<ConvOpTy>(genericOp, &dilations, &strides))       \
     return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations,        \
-                                        strides);                              \
+                                        strides);
   // -----------------------------
   // Convolution ops.
   // -----------------------------
@@ -279,17 +279,17 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
   CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
   CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
   CONV_OP_SPECIALIZER(linalg::Conv2DOp);
-  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcOp);
   CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfOp);
-  CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcOp);
   CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwOp);
   CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwQOp);
-  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwFgchwOp);
   CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwOp);
-  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfQOp);
-  CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcQOp);
   CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwQOp);
   CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwFgchwOp);
   CONV_OP_SPECIALIZER(linalg::Conv3DOp);
   // -----------------------------
   // Depthwise Convolution ops.
@@ -299,8 +299,8 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcOp);
-  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcQOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmQOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
   // -----------------------------
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 57593abac7ab0..0491f1b332d37 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -647,13 +647,13 @@ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
 }
 
 // #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
-// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)>
 // #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
 template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNhwcFhwcOp>(op))
+  if (isa<linalg::Conv2DNhwcHwcfOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
@@ -683,7 +683,7 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
   if (!convLayoutMatches(
           {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
                          W * (*strides)[1] + w * (*dilations)[1], c},
-           /*filterMap=*/{F, h, w, c},
+           /*filterMap=*/{h, w, c, F},
            /*outputMap=*/{N, H, W, F}},
           indexingMaps, context))
     return false;
@@ -696,12 +696,13 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
 
 // #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
 // #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)>
+// #scalarMap = affine_map<(N, H, W, F, h, w, c) -> ()>
 // #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
 template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNhwcHwcfOp>(op))
+  if (isa<linalg::Conv2DNhwcHwcfQOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
@@ -720,11 +721,11 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
   ArrayAttr indexingMaps = op.getIndexingMaps();
   // First fetch dilations/strides :-
   // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
                                   /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
   // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
                                   /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
   // Match expected indexing maps
@@ -732,6 +733,8 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
           {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
                          W * (*strides)[1] + w * (*dilations)[1], c},
            /*filterMap=*/{h, w, c, F},
+           /*scalarMap=*/{},
+           /*scalarMap=*/{},
            /*outputMap=*/{N, H, W, F}},
           indexingMaps, context))
     return false;
@@ -739,17 +742,17 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
 }
 
-// #inputMap  = affine_map<(N, F, H, W, C, h, w) -> (N, C, H + h, W + w)>
-// #filterMap = affine_map<(N, F, H, W, C, h, w) -> (F, C, h, w)>
-// #outputMap = affine_map<(N, F, H, W, C, h, w) -> (N, F, H, W)>
+// #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
 template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNchwFchwOp>(op))
+  if (isa<linalg::Conv2DNhwcFhwcOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
@@ -759,28 +762,28 @@ bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
   *strides = SmallVector<int64_t>(2, 1);
   MLIRContext *context = op->getContext();
   AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr F = getAffineDimExpr(1, context);
-  AffineExpr H = getAffineDimExpr(2, context);
-  AffineExpr W = getAffineDimExpr(3, context);
-  AffineExpr C = getAffineDimExpr(4, context);
-  AffineExpr h = getAffineDimExpr(5, context);
-  AffineExpr w = getAffineDimExpr(6, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr F = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  AffineExpr c = getAffineDimExpr(6, context);
   ArrayAttr indexingMaps = op.getIndexingMaps();
   // First fetch dilations/strides :-
   // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
-                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
   // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
-                                  /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
   // Match expected indexing maps
   if (!convLayoutMatches(
-          {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1]},
-           /*filterMap=*/{F, C, h, w},
-           /*outputMap=*/{N, F, H, W}},
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{F, h, w, c},
+           /*outputMap=*/{N, H, W, F}},
           indexingMaps, context))
     return false;
   // Match body
@@ -841,6 +844,54 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(
   return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
 }
 
+// #inputMap  = affine_map<(N, F, H, W, C, h, w) -> (N, C, H + h, W + w)>
+// #filterMap = affine_map<(N, F, H, W, C, h, w) -> (F, C, h, w)>
+// #outputMap = affine_map<(N, F, H, W, C, h, w) -> (N, F, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNchwFchwOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr F = getAffineDimExpr(1, context);
+  AffineExpr H = getAffineDimExpr(2, context);
+  AffineExpr W = getAffineDimExpr(3, context);
+  AffineExpr C = getAffineDimExpr(4, context);
+  AffineExpr h = getAffineDimExpr(5, context);
+  AffineExpr w = getAffineDimExpr(6, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides :-
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                  /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1]},
+           /*filterMap=*/{F, C, h, w},
+           /*outputMap=*/{N, F, H, W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
 // #inputMap  = affine_map<(N, F, H, W, C, h, w) -> (N, C, H + h, W + w)>
 // #filterMap = affine_map<(N, F, H, W, C, h, w) -> (F, C, h, w)>
 // #scalarMap = affine_map<(N, F, H, W, C, h, w) -> ()>
@@ -893,13 +944,13 @@ bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(
 }
 
 // #inputMap  = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, C, H + h, W + w)>
-// #filterMap = affine_map<(N, G, FG, H, W, C, h, w) -> (FG, G, C, h, w)>
+// #filterMap = affine_map<(N, G, FG, H, W, C, h, w) -> (G, FG, C, h, w)>
 // #outputMap = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, FG, H, W)>
 template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNgchwFgchwOp>(op))
+  if (isa<linalg::Conv2DNgchwGfchwOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
@@ -930,7 +981,7 @@ bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
   if (!convLayoutMatches(
           {/*inputMap=*/{N, G, C, H * (*strides)[0] + h * (*dilations)[0],
                          W * (*strides)[1] + w * (*dilations)[1]},
-           /*filterMap=*/{FG, G, C, h, w},
+           /*filterMap=*/{G, FG, C, h, w},
            /*outputMap=*/{N, G, FG, H, W}},
           indexingMaps, context))
     return false;
@@ -943,12 +994,13 @@ bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
 
 // #inputMap  = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, C, H + h, W + w)>
 // #filterMap = affine_map<(N, G, FG, H, W, C, h, w) -> (G, FG, C, h, w)>
+// #scalarMap = affine_map<(N, G, FG, H, W, C, h, w) -> ()>
 // #outputMap = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, FG, H, W)>
 template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNgchwGfchwOp>(op))
+  if (isa<linalg::Conv2DNgchwGfchwQOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
@@ -980,6 +1032,8 @@ bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
           {/*inputMap=*/{N, G, C, H * (*strides)[0] + h * (*dilations)[0],
                          W * (*strides)[1] + w * (*dilations)[1]},
            /*filterMap=*/{G, FG, C, h, w},
+           /*scalarMap=*/{},
+           /*scalarMap=*/{},
            /*outputMap=*/{N, G, FG, H, W}},
           indexingMaps, context))
     return false;
@@ -987,18 +1041,17 @@ bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
 }
 
-// #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
-// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)>
-// #scalarMap = affine_map<(N, H, W, F, h, w, c) -> ()>
-// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
+// #inputMap  = affine_map<(N, H, W, G, FG, h, w, c) -> (N, H + h, W + w, G, c)>
+// #filterMap = affine_map<(N, H, W, G, FG, h, w, c) -> (G, FG, h, w, c)>
+// #outputMap = affine_map<(N, H, W, G, FG, h, w, c) -> (N, H, W, G, FG)>
 template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNhwcHwcfQOp>(op))
+  if (isa<linalg::Conv2DNhwgcGfhwcOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
@@ -1010,35 +1063,34 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(
   AffineExpr N = getAffineDimExpr(0, context);
   AffineExpr H = getAffineDimExpr(1, context);
   AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr F = getAffineDimExpr(3, context);
-  AffineExpr h = getAffineDimExpr(4, context);
-  AffineExpr w = getAffineDimExpr(5, context);
-  AffineExpr c = getAffineDimExpr(6, context);
+  AffineExpr G = getAffineDimExpr(3, context);
+  AffineExpr FG = getAffineDimExpr(4, context);
+  AffineExpr h = getAffineDimExpr(5, context);
+  AffineExpr w = getAffineDimExpr(6, context);
+  AffineExpr c = getAffineDimExpr(7, context);
   ArrayAttr indexingMaps = op.getIndexingMaps();
   // First fetch dilations/strides :-
   // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2,
                                   /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
   // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3,
                                   /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
   // Match expected indexing maps
   if (!convLayoutMatches(
           {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1], c},
-           /*filterMap=*/{h, w, c, F},
-           /*scalarMap=*/{},
-           /*scalarMap=*/{},
-           /*outputMap=*/{N, H, W, F}},
+                         W * (*strides)[1] + w * (*dilations)[1], G, c},
+           /*filterMap=*/{G, FG, h, w, c},
+           /*outputMap=*/{N, H, W, G, FG}},
           indexingMaps, context))
     return false;
   // Match body
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
 }
 
 // #inputMap  = affine_map<(N, H, W, G, FG, h, w, c) -> (N, H + h, W + w, G, c)>
@@ -1094,14 +1146,13 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(
 }
 
 // #inputMap  = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, C, H + h, W + w)>
-// #filterMap = affine_map<(N, G, FG, H, W, C, h, w) -> (G, FG, C, h, w)>
-// #scalarMap = affine_map<(N, G, FG, H, W, C, h, w) -> ()>
+// #filterMap = affine_map<(N, G, FG, H, W, C, h, w) -> (FG, G, C, h, w)>
 // #outputMap = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, FG, H, W)>
 template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNgchwGfchwQOp>(op))
+  if (isa<linalg::Conv2DNgchwFgchwOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
@@ -1132,9 +1183,7 @@ bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
   if (!convLayoutMatches(
           {/*inputMap=*/{N, G, C, H * (*strides)[0] + h * (*dilations)[0],
                          W * (*strides)[1] + w * (*dilations)[1]},
-           /*filterMap=*/{G, FG, C, h, w},
-           /*scalarMap=*/{},
-           /*scalarMap=*/{},
+           /*filterMap=*/{FG, G, C, h, w},
            /*outputMap=*/{N, G, FG, H, W}},
           indexingMaps, context))
     return false;
@@ -1142,55 +1191,6 @@ bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
-}
-
-// #inputMap  = affine_map<(N, H, W, G, FG, h, w, C) -> (N, H + h, W + w, G, C)>
-// #filterMap = affine_map<(N, H, W, G, FG, h, w, C) -> (G, FG, h, w, C)>
-// #outputMap = affine_map<(N, H, W, G, FG, h, w, C) -> (N, H, W, G, FG)>
-template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(
-    LinalgOp op, SmallVector<int64_t> *dilations,
-    SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNhwgcGfhwcOp>(op))
-    return true;
-
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
-
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr G = getAffineDimExpr(3, context);
-  AffineExpr FG = getAffineDimExpr(4, context);
-  AffineExpr h = getAffineDimExpr(5, context);
-  AffineExpr w = getAffineDimExpr(6, context);
-  AffineExpr C = getAffineDimExpr(7, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
-                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1], G, C},
-           /*filterMap=*/{G, FG, h, w, C},
-           /*outputMap=*/{N, H, W, G, FG}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
   return bodyMatcherForConvolutionOps(yieldVal, body);
 }
 
@@ -1461,14 +1461,15 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(
   return bodyMatcherForConvolutionOps(yieldVal, body);
 }
 
-// #inputMap  = affine_map<(N, H, W, c, cm, h, w) -> (N, H + h, W + w, c)>
-// #filterMap = affine_map<(N, H, W, c, cm, h, w) -> (h, w, c, cm)>
-// #outputMap = affine_map<(N, H, W, c, cm, h, w) -> (N, H, W, c, cm)>
+// #inputMap  = affine_map<(N, H, W, c, h, w) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, c, h, w) -> (h, w, c)>
+// #filterMap = affine_map<(N, H, W, c, h, w) -> ()>
+// #outputMap = affine_map<(N, H, W, c, h, w) -> (N, H, W, c)>
 template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op))
+  if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
@@ -1481,9 +1482,8 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
   AffineExpr H = getAffineDimExpr(1, context);
   AffineExpr W = getAffineDimExpr(2, context);
   AffineExpr c = getAffineDimExpr(3, context);
-  AffineExpr cm = getAffineDimExpr(4, context);
-  AffineExpr h = getAffineDimExpr(5, context);
-  AffineExpr w = getAffineDimExpr(6, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
   ArrayAttr indexingMaps = op.getIndexingMaps();
   // First fetch dilations/strides :-
   // Match: H * stride + h * dilation
@@ -1498,26 +1498,27 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
   if (!convLayoutMatches(
           {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
                          W * (*strides)[1] + w * (*dilations)[1], c},
-           /*filterMap=*/{h, w, c, cm},
-           /*outputMap=*/{N, H, W, c, cm}},
+           /*filterMap=*/{h, w, c},
+           /*scalarMap=*/{},
+           /*scalarMap=*/{},
+           /*outputMap=*/{N, H, W, c}},
           indexingMaps, context))
     return false;
   // Match body
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
 }
 
-// #inputMap  = affine_map<(N, H, W, c, h, w) -> (N, H + h, W + w, c)>
-// #filterMap = affine_map<(N, H, W, c, h, w) -> (h, w, c)>
-// #filterMap = affine_map<(N, H, W, c, h, w) -> ()>
-// #outputMap = affine_map<(N, H, W, c, h, w) -> (N, H, W, c)>
+// #inputMap  = affine_map<(N, H, W, c, cm, h, w) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, c, cm, h, w) -> (h, w, c, cm)>
+// #outputMap = affine_map<(N, H, W, c, cm, h, w) -> (N, H, W, c, cm)>
 template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op))
+  if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
@@ -1530,8 +1531,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
   AffineExpr H = getAffineDimExpr(1, context);
   AffineExpr W = getAffineDimExpr(2, context);
   AffineExpr c = getAffineDimExpr(3, context);
-  AffineExpr h = getAffineDimExpr(4, context);
-  AffineExpr w = getAffineDimExpr(5, context);
+  AffineExpr cm = getAffineDimExpr(4, context);
+  AffineExpr h = getAffineDimExpr(5, context);
+  AffineExpr w = getAffineDimExpr(6, context);
   ArrayAttr indexingMaps = op.getIndexingMaps();
   // First fetch dilations/strides :-
   // Match: H * stride + h * dilation
@@ -1546,17 +1548,15 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
   if (!convLayoutMatches(
           {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
                          W * (*strides)[1] + w * (*dilations)[1], c},
-           /*filterMap=*/{h, w, c},
-           /*scalarMap=*/{},
-           /*scalarMap=*/{},
-           /*outputMap=*/{N, H, W, c}},
+           /*filterMap=*/{h, w, c, cm},
+           /*outputMap=*/{N, H, W, c, cm}},
           indexingMaps, context))
     return false;
   // Match body
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
 }
 
 // #inputMap  = affine_map<(N, H, W, c, cm, h, w) -> (N, H + h, W + w, c)>
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
index 11a2672d04632..f602668af7daa 100644
--- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -81,6 +81,32 @@ func.func @conv_2d_nchw_fchw_q(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?x?
 
 // -----
 
+func.func @conv_2d_nhwc_fhwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc
+         {dilations = dense<[2,4]> : tensor<2xi64>, strides = dense<[3,5]> : tensor<2xi64>}
+         ins(%input, %filter : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+         outs(%init : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nhwc_fhwc
+//      CHECK:   linalg.conv_2d_nhwc_fhwc
+// CHECK-SAME:      dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_nhwc_fhwc_q(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?x?x?xi8>, %inputzp: i32, %filterzp: i32, %init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+  %0 = linalg.conv_2d_nhwc_fhwc_q
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+         ins (%input, %filter, %inputzp, %filterzp: tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32)
+         outs (%init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?xi32>
+}
+//      CHECK: @conv_2d_nhwc_fhwc_q
+//      CHECK:   linalg.conv_2d_nhwc_fhwc_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
 func.func @conv_2d_ngchw_fgchw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
   %0 = linalg.conv_2d_ngchw_fgchw
          {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
@@ -133,6 +159,19 @@ func.func @conv_2d_nhwc_hwcf_q(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x
 
 // -----
 
+func.func @conv_2d_nhwgc_gfhwc(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwgc_gfhwc
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+         ins(%input, %filter : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+         outs(%output : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nhwgc_gfhwc
+//      CHECK:   linalg.conv_2d_nhwgc_gfhwc
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
 func.func @conv_2d_nhwgc_gfhwc_q(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %inputzp: i32, %filterzp: i32, %output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
   %0 = linalg.conv_2d_nhwgc_gfhwc_q
          {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}

>From 3c90eee9e78f8229607255cb6eec1294755e7726 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 26 Nov 2025 07:17:52 -0600
Subject: [PATCH 3/3] [Linalg] Refactor conv matchers to use ConvMatcherBuilder (suggested by Hanhan)

---
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 1484 ++++++++---------------
 1 file changed, 527 insertions(+), 957 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 0491f1b332d37..cd303a7e7caf1 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -463,10 +463,6 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
   return false;
 }
 
-// ---------------------------------------------
-// Matchers for specific convolution operation.
-// ---------------------------------------------
-
 /// Returns true if the given indexing maps matches with the expected indexing
 /// maps.
 static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
@@ -481,6 +477,113 @@ static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
                           })));
 }
 
+/// Kind of body that ConvMatcherBuilder::matchBody() checks for. `NONE`
+/// selects the convolution body matcher; every other value selects the
+/// pooling body matcher of the same kind. Scoped so that generic enumerator
+/// names such as NONE and SUM do not leak into the enclosing namespace.
+enum class PoolingType {
+  NONE,
+  MAX_SIGNED,
+  MAX_UNSIGNED,
+  MIN_SIGNED,
+  MIN_UNSIGNED,
+  SUM
+};
+
+/// Helper class for building convolution op matchers with minimal boilerplate.
+/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants.
+///
+/// Typical use: one matchStride() per spatial dimension, then expectMaps(),
+/// then matchBody(). Once a step fails, the remaining steps are no-ops and
+/// matchBody() returns false.
+class ConvMatcherBuilder {
+  LinalgOp op;
+  MLIRContext *ctx;
+  SmallVector<int64_t> *dilations, *strides;
+  ArrayAttr indexingMaps;
+  PoolingType poolingType;
+  bool matched = true;
+
+public:
+  /// Resets `*d` and `*s` to `spatialRank` entries of 1; matchStride() then
+  /// fills them in. Both pointers must be non-null.
+  ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d,
+                     SmallVector<int64_t> *s,
+                     PoolingType poolingType = PoolingType::NONE)
+      : op(op), ctx(op->getContext()), dilations(d), strides(s),
+        indexingMaps(op.getIndexingMaps()), poolingType(poolingType) {
+    *dilations = SmallVector<int64_t>(spatialRank, 1);
+    *strides = SmallVector<int64_t>(spatialRank, 1);
+  }
+
+  /// Get affine dimension expression for dimension i.
+  AffineExpr dim(unsigned i) const { return getAffineDimExpr(i, ctx); }
+
+  /// Build strided expression: base * stride[idx] + kernel * dilation[idx].
+  /// Reads the values recorded by matchStride(..., idx), so it must be
+  /// evaluated after that call. In `m.matchStride(...).expectMaps({...})` this
+  /// holds because C++17 sequences a member call's object expression before
+  /// its arguments.
+  AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx) const {
+    return base * (*strides)[idx] + kernel * (*dilations)[idx];
+  }
+
+  /// Match stride/dilation pattern for spatial dimension `idx`. `iDim`, `fDim`
+  /// and `oDim` are the positions of that dimension's expression in the
+  /// input, filter and output maps (e.g. `fDim` is 0 for `h` in an
+  /// (h, w, c, F) filter). Slot `idx` of the dilation/stride vectors is
+  /// passed to matchConvDimAddExprPattern, which fetches the values into it.
+  /// Returns *this for method chaining.
+  ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim,
+                                  unsigned idx) {
+    if (matched) {
+      matched = matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
+                                           (*dilations)[idx], (*strides)[idx]);
+    }
+    return *this;
+  }
+
+  /// Match expected indexing maps layout.
+  /// Returns *this for method chaining.
+  ConvMatcherBuilder &expectMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
+    if (matched)
+      matched = convLayoutMatches(maps, indexingMaps, ctx);
+    return *this;
+  }
+
+  /// Match the body against the matcher selected by `poolingType`. This
+  /// should be called last; it returns true only if every step matched.
+  /// `zeroPointOffset` is only forwarded to the convolution body matcher
+  /// (the quantized `*Q` ops pass true).
+  bool matchBody(bool zeroPointOffset = false) {
+    if (!matched)
+      return false;
+    Block *body = op.getBlock();
+    auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+    switch (poolingType) {
+    case PoolingType::NONE:
+      return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body,
+                                          zeroPointOffset);
+    case PoolingType::MAX_SIGNED:
+      return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body);
+    case PoolingType::MAX_UNSIGNED:
+      return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body);
+    case PoolingType::MIN_SIGNED:
+      return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body);
+    case PoolingType::MIN_UNSIGNED:
+      return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body);
+    case PoolingType::SUM:
+      return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body);
+    }
+    // Not reached for valid PoolingType values; keeps -Wreturn-type quiet.
+    return false;
+  }
+};
+
+// ---------------------------------------------
+// Matchers for specific convolution operation.
+// ---------------------------------------------
+
 // #inputMap = affine_map<(W, w) -> (W + w)>
 // #filterMap = affine_map<(W, w) -> (w)>
 // #outputMap = affine_map<(W, w) -> (W)>
@@ -494,29 +576,15 @@ bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr W = getAffineDimExpr(0, context);
-  AffineExpr w = getAffineDimExpr(1, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
-                                  /*oDim=*/0, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{W * (*strides)[0] + w * (*dilations)[0]},
-           /*filterMap=*/{w},
-           /*outputMap=*/{W}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+  AffineExpr W = m.dim(0);
+  AffineExpr w = m.dim(1);
+
+  return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+      .expectMaps({/*inputMap=*/{m.strided(W, w, 0)},
+                   /*filterMap=*/{w},
+                   /*outputMap=*/{W}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(N, W, F, w, c) -> (N, W + w, c)>
@@ -532,32 +600,18 @@ bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr W = getAffineDimExpr(1, context);
-  AffineExpr F = getAffineDimExpr(2, context);
-  AffineExpr w = getAffineDimExpr(3, context);
-  AffineExpr c = getAffineDimExpr(4, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], c},
-           /*filterMap=*/{w, c, F},
-           /*outputMap=*/{N, W, F}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr F = m.dim(2);
+  AffineExpr w = m.dim(3);
+  AffineExpr c = m.dim(4);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
+                   /*filterMap=*/{w, c, F},
+                   /*outputMap=*/{N, W, F}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(N, F, W, c, w) -> (N, c, W + w)>
@@ -573,32 +627,18 @@ bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr F = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr c = getAffineDimExpr(3, context);
-  AffineExpr w = getAffineDimExpr(4, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
-                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, c, W * (*strides)[0] + w * (*dilations)[0]},
-           /*filterMap=*/{F, c, w},
-           /*outputMap=*/{N, F, W}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr F = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr c = m.dim(3);
+  AffineExpr w = m.dim(4);
+
+  return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
+      .expectMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
+                   /*filterMap=*/{F, c, w},
+                   /*outputMap=*/{N, F, W}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(H, W, h, w) -> (H + h, W + w)>
@@ -614,36 +654,18 @@ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr H = getAffineDimExpr(0, context);
-  AffineExpr W = getAffineDimExpr(1, context);
-  AffineExpr h = getAffineDimExpr(2, context);
-  AffineExpr w = getAffineDimExpr(3, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
-                                  /*oDim=*/0, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
-                                  /*oDim=*/1, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1]},
-           /*filterMap=*/{h, w},
-           /*outputMap=*/{H, W}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr H = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr h = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+      .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
+                   /*filterMap=*/{h, w},
+                   /*outputMap=*/{H, W}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
@@ -659,39 +681,21 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr F = getAffineDimExpr(3, context);
-  AffineExpr h = getAffineDimExpr(4, context);
-  AffineExpr w = getAffineDimExpr(5, context);
-  AffineExpr c = getAffineDimExpr(6, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
-                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1], c},
-           /*filterMap=*/{h, w, c, F},
-           /*outputMap=*/{N, H, W, F}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr F = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+  AffineExpr c = m.dim(6);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
+                   /*filterMap=*/{h, w, c, F},
+                   /*outputMap=*/{N, H, W, F}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
@@ -708,41 +712,23 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr F = getAffineDimExpr(3, context);
-  AffineExpr h = getAffineDimExpr(4, context);
-  AffineExpr w = getAffineDimExpr(5, context);
-  AffineExpr c = getAffineDimExpr(6, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1], c},
-           /*filterMap=*/{h, w, c, F},
-           /*scalarMap=*/{},
-           /*scalarMap=*/{},
-           /*outputMap=*/{N, H, W, F}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr F = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+  AffineExpr c = m.dim(6);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
+                   /*filterMap=*/{h, w, c, F},
+                   /*scalarMap=*/{},
+                   /*scalarMap=*/{},
+                   /*outputMap=*/{N, H, W, F}})
+      .matchBody(/*zeroPointOffset=*/true);
 }
 
 // #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
@@ -758,39 +744,21 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr F = getAffineDimExpr(3, context);
-  AffineExpr h = getAffineDimExpr(4, context);
-  AffineExpr w = getAffineDimExpr(5, context);
-  AffineExpr c = getAffineDimExpr(6, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
-                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1], c},
-           /*filterMap=*/{F, h, w, c},
-           /*outputMap=*/{N, H, W, F}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr F = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+  AffineExpr c = m.dim(6);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
+                   /*filterMap=*/{F, h, w, c},
+                   /*outputMap=*/{N, H, W, F}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
@@ -807,41 +775,23 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr F = getAffineDimExpr(3, context);
-  AffineExpr h = getAffineDimExpr(4, context);
-  AffineExpr w = getAffineDimExpr(5, context);
-  AffineExpr c = getAffineDimExpr(6, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
-                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1], c},
-           /*filterMap=*/{F, h, w, c},
-           /*scalarMap=*/{},
-           /*scalarMap=*/{},
-           /*outputMap=*/{N, H, W, F}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr F = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+  AffineExpr c = m.dim(6);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
+                   /*filterMap=*/{F, h, w, c},
+                   /*scalarMap=*/{},
+                   /*scalarMap=*/{},
+                   /*outputMap=*/{N, H, W, F}})
+      .matchBody(/*zeroPointOffset=*/true);
 }
 
 // #inputMap  = affine_map<(N, F, H, W, C, h, w) -> (N, C, H + h, W + w)>
@@ -857,39 +807,21 @@ bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr F = getAffineDimExpr(1, context);
-  AffineExpr H = getAffineDimExpr(2, context);
-  AffineExpr W = getAffineDimExpr(3, context);
-  AffineExpr C = getAffineDimExpr(4, context);
-  AffineExpr h = getAffineDimExpr(5, context);
-  AffineExpr w = getAffineDimExpr(6, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
-                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
-                                  /*oDim=*/3, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1]},
-           /*filterMap=*/{F, C, h, w},
-           /*outputMap=*/{N, F, H, W}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr F = m.dim(1);
+  AffineExpr H = m.dim(2);
+  AffineExpr W = m.dim(3);
+  AffineExpr C = m.dim(4);
+  AffineExpr h = m.dim(5);
+  AffineExpr w = m.dim(6);
+
+  return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
+      .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
+                   /*filterMap=*/{F, C, h, w},
+                   /*outputMap=*/{N, F, H, W}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(N, F, H, W, C, h, w) -> (N, C, H + h, W + w)>
@@ -906,41 +838,23 @@ bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr F = getAffineDimExpr(1, context);
-  AffineExpr H = getAffineDimExpr(2, context);
-  AffineExpr W = getAffineDimExpr(3, context);
-  AffineExpr C = getAffineDimExpr(4, context);
-  AffineExpr h = getAffineDimExpr(5, context);
-  AffineExpr w = getAffineDimExpr(6, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
-                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
-                                  /*oDim=*/3, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1]},
-           /*filterMap=*/{F, C, h, w},
-           /*scalarMap=*/{},
-           /*scalarMap=*/{},
-           /*outputMap=*/{N, F, H, W}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr F = m.dim(1);
+  AffineExpr H = m.dim(2);
+  AffineExpr W = m.dim(3);
+  AffineExpr C = m.dim(4);
+  AffineExpr h = m.dim(5);
+  AffineExpr w = m.dim(6);
+
+  return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
+      .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
+                   /*filterMap=*/{F, C, h, w},
+                   /*scalarMap=*/{},
+                   /*scalarMap=*/{},
+                   /*outputMap=*/{N, F, H, W}})
+      .matchBody(/*zeroPointOffset=*/true);
 }
 
 // #inputMap  = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, C, H + h, W + w)>
@@ -956,40 +870,23 @@ bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr G = getAffineDimExpr(1, context);
-  AffineExpr FG = getAffineDimExpr(2, context);
-  AffineExpr H = getAffineDimExpr(3, context);
-  AffineExpr W = getAffineDimExpr(4, context);
-  AffineExpr C = getAffineDimExpr(5, context);
-  AffineExpr h = getAffineDimExpr(6, context);
-  AffineExpr w = getAffineDimExpr(7, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
-                                  /*oDim=*/3, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
-                                  /*oDim=*/4, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, G, C, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1]},
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr G = m.dim(1);
+  AffineExpr FG = m.dim(2);
+  AffineExpr H = m.dim(3);
+  AffineExpr W = m.dim(4);
+  AffineExpr C = m.dim(5);
+  AffineExpr h = m.dim(6);
+  AffineExpr w = m.dim(7);
+
+  return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
+      .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
+      .expectMaps(
+          {/*inputMap=*/{N, G, C, m.strided(H, h, 0), m.strided(W, w, 1)},
            /*filterMap=*/{G, FG, C, h, w},
-           /*outputMap=*/{N, G, FG, H, W}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+           /*outputMap=*/{N, G, FG, H, W}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, C, H + h, W + w)>
@@ -1006,42 +903,25 @@ bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr G = getAffineDimExpr(1, context);
-  AffineExpr FG = getAffineDimExpr(2, context);
-  AffineExpr H = getAffineDimExpr(3, context);
-  AffineExpr W = getAffineDimExpr(4, context);
-  AffineExpr C = getAffineDimExpr(5, context);
-  AffineExpr h = getAffineDimExpr(6, context);
-  AffineExpr w = getAffineDimExpr(7, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
-                                  /*oDim=*/3, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
-                                  /*oDim=*/4, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, G, C, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1]},
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr G = m.dim(1);
+  AffineExpr FG = m.dim(2);
+  AffineExpr H = m.dim(3);
+  AffineExpr W = m.dim(4);
+  AffineExpr C = m.dim(5);
+  AffineExpr h = m.dim(6);
+  AffineExpr w = m.dim(7);
+
+  return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
+      .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
+      .expectMaps(
+          {/*inputMap=*/{N, G, C, m.strided(H, h, 0), m.strided(W, w, 1)},
            /*filterMap=*/{G, FG, C, h, w},
            /*scalarMap=*/{},
            /*scalarMap=*/{},
-           /*outputMap=*/{N, G, FG, H, W}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+           /*outputMap=*/{N, G, FG, H, W}})
+      .matchBody(/*zeroPointOffset=*/true);
 }
 
 // #inputMap  = affine_map<(N, H, W, G, FG, h, w, c) -> (N, H + h, W + w, G, c)>
@@ -1057,40 +937,23 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr G = getAffineDimExpr(3, context);
-  AffineExpr FG = getAffineDimExpr(4, context);
-  AffineExpr h = getAffineDimExpr(5, context);
-  AffineExpr w = getAffineDimExpr(6, context);
-  AffineExpr c = getAffineDimExpr(7, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3,
-                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1], G, c},
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr G = m.dim(3);
+  AffineExpr FG = m.dim(4);
+  AffineExpr h = m.dim(5);
+  AffineExpr w = m.dim(6);
+  AffineExpr c = m.dim(7);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1)
+      .expectMaps(
+          {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
            /*filterMap=*/{G, FG, h, w, c},
-           /*outputMap=*/{N, H, W, G, FG}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+           /*outputMap=*/{N, H, W, G, FG}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(N, H, W, G, FG, h, w, c) -> (N, H + h, W + w, G, c)>
@@ -1107,42 +970,25 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr G = getAffineDimExpr(3, context);
-  AffineExpr FG = getAffineDimExpr(4, context);
-  AffineExpr h = getAffineDimExpr(5, context);
-  AffineExpr w = getAffineDimExpr(6, context);
-  AffineExpr c = getAffineDimExpr(7, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3,
-                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1], G, c},
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr G = m.dim(3);
+  AffineExpr FG = m.dim(4);
+  AffineExpr h = m.dim(5);
+  AffineExpr w = m.dim(6);
+  AffineExpr c = m.dim(7);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1)
+      .expectMaps(
+          {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
            /*filterMap=*/{G, FG, h, w, c},
            /*scalarMap=*/{},
            /*scalarMap=*/{},
-           /*outputMap=*/{N, H, W, G, FG}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+           /*outputMap=*/{N, H, W, G, FG}})
+      .matchBody(/*zeroPointOffset=*/true);
 }
 
 // #inputMap  = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, C, H + h, W + w)>
@@ -1158,40 +1004,23 @@ bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr G = getAffineDimExpr(1, context);
-  AffineExpr FG = getAffineDimExpr(2, context);
-  AffineExpr H = getAffineDimExpr(3, context);
-  AffineExpr W = getAffineDimExpr(4, context);
-  AffineExpr C = getAffineDimExpr(5, context);
-  AffineExpr h = getAffineDimExpr(6, context);
-  AffineExpr w = getAffineDimExpr(7, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
-                                  /*oDim=*/3, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
-                                  /*oDim=*/4, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, G, C, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1]},
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr G = m.dim(1);
+  AffineExpr FG = m.dim(2);
+  AffineExpr H = m.dim(3);
+  AffineExpr W = m.dim(4);
+  AffineExpr C = m.dim(5);
+  AffineExpr h = m.dim(6);
+  AffineExpr w = m.dim(7);
+
+  return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
+      .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
+      .expectMaps(
+          {/*inputMap=*/{N, G, C, m.strided(H, h, 0), m.strided(W, w, 1)},
            /*filterMap=*/{FG, G, C, h, w},
-           /*outputMap=*/{N, G, FG, H, W}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+           /*outputMap=*/{N, G, FG, H, W}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)>
@@ -1207,43 +1036,22 @@ bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(3, 1);
-  *strides = SmallVector<int64_t>(3, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr D = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr d = getAffineDimExpr(3, context);
-  AffineExpr h = getAffineDimExpr(4, context);
-  AffineExpr w = getAffineDimExpr(5, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: D * stride + d * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
-                                  /*oDim=*/0, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
-                                  /*oDim=*/1, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
-                                  /*oDim=*/2, (*dilations)[2], (*strides)[2]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{D * (*strides)[0] + d * (*dilations)[0],
-                         H * (*strides)[1] + h * (*dilations)[1],
-                         W * (*strides)[2] + w * (*dilations)[2]},
-           /*filterMap=*/{d, h, w},
-           /*outputMap=*/{D, H, W}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
+  AffineExpr D = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr d = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+
+  return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+      .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
+      .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2)
+      .expectMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
+                                 m.strided(W, w, 2)},
+                   /*filterMap=*/{d, h, w},
+                   /*outputMap=*/{D, H, W}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(N, W, C, w) -> (N, C, W + w)>
@@ -1259,31 +1067,17 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr W = getAffineDimExpr(1, context);
-  AffineExpr C = getAffineDimExpr(2, context);
-  AffineExpr w = getAffineDimExpr(3, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, C, W * (*strides)[0] + w * (*dilations)[0]},
-           /*filterMap=*/{C, w},
-           /*outputMap=*/{N, C, W}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr C = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
+      .expectMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
+                   /*filterMap=*/{C, w},
+                   /*outputMap=*/{N, C, W}})
+      .matchBody();
 }
 
 // #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
@@ -1299,31 +1093,17 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr W = getAffineDimExpr(1, context);
-  AffineExpr C = getAffineDimExpr(2, context);
-  AffineExpr w = getAffineDimExpr(3, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
-           /*filterMap=*/{w, C},
-           /*outputMap=*/{N, W, C}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr C = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+                   /*filterMap=*/{w, C},
+                   /*outputMap=*/{N, W, C}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(N, W, C, CM, w) -> (N, W + w, C)>
@@ -1339,32 +1119,18 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr W = getAffineDimExpr(1, context);
-  AffineExpr C = getAffineDimExpr(2, context);
-  AffineExpr CM = getAffineDimExpr(3, context);
-  AffineExpr w = getAffineDimExpr(4, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
-           /*filterMap=*/{w, C, CM},
-           /*outputMap=*/{N, W, C, CM}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr C = m.dim(2);
+  AffineExpr CM = m.dim(3);
+  AffineExpr w = m.dim(4);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+                   /*filterMap=*/{w, C, CM},
+                   /*outputMap=*/{N, W, C, CM}})
+      .matchBody();
 }
 
 // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
@@ -1380,38 +1146,20 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr C = getAffineDimExpr(3, context);
-  AffineExpr h = getAffineDimExpr(4, context);
-  AffineExpr w = getAffineDimExpr(5, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
-                                  /*oDim=*/3, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1]},
-           /*filterMap=*/{C, h, w},
-           /*outputMap=*/{N, C, H, W}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr C = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+
+  return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
+      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
+                   /*filterMap=*/{C, h, w},
+                   /*outputMap=*/{N, C, H, W}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(N, H, W, c, h, w) -> (N, H + h, W + w, c)>
@@ -1427,38 +1175,20 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr c = getAffineDimExpr(3, context);
-  AffineExpr h = getAffineDimExpr(4, context);
-  AffineExpr w = getAffineDimExpr(5, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1], c},
-           /*filterMap=*/{h, w, c},
-           /*outputMap=*/{N, H, W, c}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr c = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
+                   /*filterMap=*/{h, w, c},
+                   /*outputMap=*/{N, H, W, c}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(N, H, W, c, h, w) -> (N, H + h, W + w, c)>
@@ -1475,40 +1205,22 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr c = getAffineDimExpr(3, context);
-  AffineExpr h = getAffineDimExpr(4, context);
-  AffineExpr w = getAffineDimExpr(5, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1], c},
-           /*filterMap=*/{h, w, c},
-           /*scalarMap=*/{},
-           /*scalarMap=*/{},
-           /*outputMap=*/{N, H, W, c}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr c = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
+                   /*filterMap=*/{h, w, c},
+                   /*scalarMap=*/{},
+                   /*scalarMap=*/{},
+                   /*outputMap=*/{N, H, W, c}})
+      .matchBody(/*zeroPointOffset=*/true);
 }
 
 // #inputMap  = affine_map<(N, H, W, c, cm, h, w) -> (N, H + h, W + w, c)>
@@ -1524,39 +1236,21 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr c = getAffineDimExpr(3, context);
-  AffineExpr cm = getAffineDimExpr(4, context);
-  AffineExpr h = getAffineDimExpr(5, context);
-  AffineExpr w = getAffineDimExpr(6, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1], c},
-           /*filterMap=*/{h, w, c, cm},
-           /*outputMap=*/{N, H, W, c, cm}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr c = m.dim(3);
+  AffineExpr cm = m.dim(4);
+  AffineExpr h = m.dim(5);
+  AffineExpr w = m.dim(6);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
+                   /*filterMap=*/{h, w, c, cm},
+                   /*outputMap=*/{N, H, W, c, cm}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(N, H, W, c, cm, h, w) -> (N, H + h, W + w, c)>
@@ -1573,41 +1267,23 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmQOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr c = getAffineDimExpr(3, context);
-  AffineExpr cm = getAffineDimExpr(4, context);
-  AffineExpr h = getAffineDimExpr(5, context);
-  AffineExpr w = getAffineDimExpr(6, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1], c},
-           /*filterMap=*/{h, w, c, cm},
-           /*scalarMap=*/{},
-           /*scalarMap=*/{},
-           /*outputMap=*/{N, H, W, c, cm}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr c = m.dim(3);
+  AffineExpr cm = m.dim(4);
+  AffineExpr h = m.dim(5);
+  AffineExpr w = m.dim(6);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
+                   /*filterMap=*/{h, w, c, cm},
+                   /*scalarMap=*/{},
+                   /*scalarMap=*/{},
+                   /*outputMap=*/{N, H, W, c, cm}})
+      .matchBody(/*zeroPointOffset=*/true);
 }
 
 // #inputMap = affine_map<(N, D, H, W, CM, d, h, w, C)
@@ -1626,46 +1302,25 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(3, 1);
-  *strides = SmallVector<int64_t>(3, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr D = getAffineDimExpr(1, context);
-  AffineExpr H = getAffineDimExpr(2, context);
-  AffineExpr W = getAffineDimExpr(3, context);
-  AffineExpr CM = getAffineDimExpr(4, context);
-  AffineExpr d = getAffineDimExpr(5, context);
-  AffineExpr h = getAffineDimExpr(6, context);
-  AffineExpr w = getAffineDimExpr(7, context);
-  AffineExpr C = getAffineDimExpr(8, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: D * stride + d * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
-                                  /*oDim=*/3, (*dilations)[2], (*strides)[2]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, D * (*strides)[0] + d * (*dilations)[0],
-                         H * (*strides)[1] + h * (*dilations)[1],
-                         W * (*strides)[2] + w * (*dilations)[2], C},
-           /*filterMap=*/{d, h, w, C, CM},
-           /*outputMap=*/{N, D, H, W, C, CM}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr D = m.dim(1);
+  AffineExpr H = m.dim(2);
+  AffineExpr W = m.dim(3);
+  AffineExpr CM = m.dim(4);
+  AffineExpr d = m.dim(5);
+  AffineExpr h = m.dim(6);
+  AffineExpr w = m.dim(7);
+  AffineExpr C = m.dim(8);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
+      .expectMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+                                 m.strided(W, w, 2), C},
+                   /*filterMap=*/{d, h, w, C, CM},
+                   /*outputMap=*/{N, D, H, W, C, CM}})
+      .matchBody();
 }
 
 // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
@@ -1681,38 +1336,21 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr C = getAffineDimExpr(3, context);
-  AffineExpr h = getAffineDimExpr(4, context);
-  AffineExpr w = getAffineDimExpr(5, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1], C},
-           /*filterMap=*/{h, w},
-           /*outputMap=*/{N, H, W, C}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForMaxSignedPoolOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+                       PoolingType::MAX_SIGNED);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr C = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                   /*filterMap=*/{h, w},
+                   /*outputMap=*/{N, H, W, C}})
+      .matchBody();
 }
 
 // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
@@ -1728,38 +1366,21 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr C = getAffineDimExpr(3, context);
-  AffineExpr h = getAffineDimExpr(4, context);
-  AffineExpr w = getAffineDimExpr(5, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1], C},
-           /*filterMap=*/{h, w},
-           /*outputMap=*/{N, H, W, C}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForMinSignedPoolOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+                       PoolingType::MIN_SIGNED);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr C = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                   /*filterMap=*/{h, w},
+                   /*outputMap=*/{N, H, W, C}})
+      .matchBody();
 }
 
 // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
@@ -1775,38 +1396,21 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr C = getAffineDimExpr(3, context);
-  AffineExpr h = getAffineDimExpr(4, context);
-  AffineExpr w = getAffineDimExpr(5, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1], C},
-           /*filterMap=*/{h, w},
-           /*outputMap=*/{N, H, W, C}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForSumPoolOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+                       PoolingType::SUM);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr C = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                   /*filterMap=*/{h, w},
+                   /*outputMap=*/{N, H, W, C}})
+      .matchBody();
 }
 
 // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
@@ -1822,38 +1426,21 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr C = getAffineDimExpr(3, context);
-  AffineExpr h = getAffineDimExpr(4, context);
-  AffineExpr w = getAffineDimExpr(5, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1], C},
-           /*filterMap=*/{h, w},
-           /*outputMap=*/{N, H, W, C}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForMaxUnsignedPoolOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+                       PoolingType::MAX_UNSIGNED);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr C = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                   /*filterMap=*/{h, w},
+                   /*outputMap=*/{N, H, W, C}})
+      .matchBody();
 }
 
 // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
@@ -1869,38 +1456,21 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr C = getAffineDimExpr(3, context);
-  AffineExpr h = getAffineDimExpr(4, context);
-  AffineExpr w = getAffineDimExpr(5, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1], C},
-           /*filterMap=*/{h, w},
-           /*outputMap=*/{N, H, W, C}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForMinUnsignedPoolOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+                       PoolingType::MIN_UNSIGNED);
+  AffineExpr N = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr C = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                   /*filterMap=*/{h, w},
+                   /*outputMap=*/{N, H, W, C}})
+      .matchBody();
 }
 
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,



More information about the Mlir-commits mailing list