[Mlir-commits] [mlir] [Linalg] Add matchers to infer which Convolution Op a given linalg.generic is (PR #163374)

Abhishek Varma llvmlistbot at llvm.org
Wed Oct 15 01:30:31 PDT 2025


https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/163374

>From a964bb437e612c57ce741157977502ed1f7815fb Mon Sep 17 00:00:00 2001
From: Abhishek Varma <avarma094 at gmail.com>
Date: Mon, 22 Sep 2025 12:16:48 +0000
Subject: [PATCH 01/18] [WIP] Generic to named Conv op support

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 158 ++++++++++++++++++
 1 file changed, 158 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 40fc0d68e358f..4e9572ee7cb04 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,6 +237,159 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
 }
 
+static bool matchingIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
+ArrayRef<mlir::utils::IteratorType> expectedIteratorTypes) {
+  if (iteratorTypes.size() != expectedIteratorTypes.size()) return false;
+  for (auto [orig, expected] : llvm::zip_equal(iteratorTypes, expectedIteratorTypes)) {
+    if (orig != expected) return false;
+  }
+  return true;
+}
+
+static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
+                                        uint32_t mapIndex, uint32_t dimIndex) {
+  auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
+  // uint32_t nResults = affineMap.getNumResults();
+  // llvm::outs()<<affineMap<<"\n";
+  // llvm::outs()<<"Total result = "<<affineMap.getNumResults()<<"\n";
+  // llvm::outs()<<"N = "<<nResults<<", dimIndex = "<<dimIndex<<"\n";
+  // llvm::outs().flush();
+  return affineMap.getResult(dimIndex);
+}
+
+static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
+  SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
+  SmallVector<utils::IteratorType> expectedIteratorTypes = {
+    utils::IteratorType::parallel, utils::IteratorType::reduction
+  };
+
+  if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
+    return "linalg.conv_1d";
+  return "";
+}
+
+static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
+  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
+  if (indexingMaps.size() != 3) return "";
+  SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
+  // Conv 1D
+  // depthwise_conv_1d_ncw_cw
+  // depthwise_conv_1d_nwc_wc
+  // ["parallel", "parallel", "parallel", "reduction"]
+  SmallVector<utils::IteratorType> expectedIteratorTypes = {
+    utils::IteratorType::parallel, utils::IteratorType::parallel,
+    utils::IteratorType::parallel, utils::IteratorType::reduction
+  };
+  // inputMapIndex = 0, filterMapIndex = 1, outputMapIndex = 2;
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) {
+    if (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))
+      return "linalg.depthwise_conv_1d_ncw_cw";
+    else if (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2))
+      return "linalg.depthwise_conv_1d_nwc_wc";
+  }
+
+  //
+  expectedIteratorTypes[2] = utils::IteratorType::reduction;
+  if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) {
+    return "linalg.conv_2d";
+  }
+  return "";
+}
+
+static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
+  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
+  if (indexingMaps.size() != 3) return "";
+  SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
+  // "parallel", "parallel", "parallel", "reduction", "reduction"]
+  SmallVector<utils::IteratorType> expectedIteratorTypes = {
+    utils::IteratorType::parallel, utils::IteratorType::parallel,
+    utils::IteratorType::parallel, utils::IteratorType::parallel,
+    utils::IteratorType::reduction
+  };
+  if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
+    return "linalg.depthwise_conv_1d_nwc_wcm";
+
+  expectedIteratorTypes[3] = utils::IteratorType::reduction;
+  // inputMapIndex = 0, filterMapIndex = 1, outputMapIndex = 2;
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) {
+    if (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2))
+      return "linalg.conv_1d_nwc_wcf";
+    else if (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))
+      return "linalg.conv_1d_ncw_fcw";
+  }
+  return "";
+}
+
+static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
+  SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
+  SmallVector<utils::IteratorType> expectedIteratorTypes = {
+    utils::IteratorType::parallel, utils::IteratorType::reduction
+  };
+  if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
+    return "linalg.conv_1d";
+  return "";
+}
+
+static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
+  SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
+  SmallVector<utils::IteratorType> expectedIteratorTypes = {
+    utils::IteratorType::parallel, utils::IteratorType::reduction
+  };
+  if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
+    return "linalg.conv_1d";
+  return "";
+}
+
+static std::string inferConvolutionKind(GenericOp genericOp) {
+  SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
+  unsigned totalIterators = iteratorTypes.size();
+  switch(totalIterators) {
+    case 2:
+      return inferBasedOnRank2ConvIteratorTypes(genericOp);
+    case 4:
+      return inferBasedOnRank4ConvIteratorTypes(genericOp);
+    case 5:
+      return inferBasedOnRank5ConvIteratorTypes(genericOp);
+    case 7:
+      return inferBasedOnRank7ConvIteratorTypes(genericOp);
+    case 8:
+      return inferBasedOnRank8ConvIteratorTypes(genericOp);
+  }
+  return "";
+}
+
+// Converts linalg.generic to named linalg.*conv* where possible.
+static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
+                                                        GenericOp genericOp) {
+  std::string convKind = inferConvolutionKind(genericOp);
+  if (convKind == "") return failure();
+  SmallVector<Value> inputs = genericOp.getDpsInputs();
+  ValueRange outputs = genericOp.getDpsInits();
+  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+  SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
+                                      ? TypeRange(ValueRange(outputs))
+                                      : TypeRange{};
+  LinalgOp namedOp;
+  if (convKind == "linalg.conv_1d") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv1DOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.conv_1d_nwc_wcf") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv1DNwcWcfOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.conv_1d_ncw_fcw") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv1DNcwFcwOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.depthwise_conv_1d_ncw_cw") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNcwCwOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.depthwise_conv_1d_nwc_wc") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNwcWcOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.conv_2d") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DOp>(genericOp, resultTypes, inputs, outputs);
+  }
+  return namedOp;
+
+  return failure();
+}
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -316,6 +469,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
   if (isaContractionOpInterface(genericOp)) {
     return specializeLinalgContractions(rewriter, genericOp);
   }
+
+  // Convolution - e.g. *conv*
+  if (isaConvolutionOpInterface(genericOp)) {
+    return specializeLinalgConvolutions(rewriter, genericOp);
+  }
   return failure();
 }
 

>From 89b7190e79b16954b254acca9b5f4d8f6f7c9eb6 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <avarma094 at gmail.com>
Date: Mon, 29 Sep 2025 17:27:30 +0000
Subject: [PATCH 02/18] Matching indexing maps

---
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 257 +++++++++++++-----
 1 file changed, 187 insertions(+), 70 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 4e9572ee7cb04..84b080fd53535 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,33 +237,17 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
 }
 
-static bool matchingIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
-ArrayRef<mlir::utils::IteratorType> expectedIteratorTypes) {
-  if (iteratorTypes.size() != expectedIteratorTypes.size()) return false;
-  for (auto [orig, expected] : llvm::zip_equal(iteratorTypes, expectedIteratorTypes)) {
-    if (orig != expected) return false;
-  }
-  return true;
-}
-
 static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
                                         uint32_t mapIndex, uint32_t dimIndex) {
   auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
-  // uint32_t nResults = affineMap.getNumResults();
-  // llvm::outs()<<affineMap<<"\n";
-  // llvm::outs()<<"Total result = "<<affineMap.getNumResults()<<"\n";
-  // llvm::outs()<<"N = "<<nResults<<", dimIndex = "<<dimIndex<<"\n";
-  // llvm::outs().flush();
   return affineMap.getResult(dimIndex);
 }
 
 static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
-  SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
-  SmallVector<utils::IteratorType> expectedIteratorTypes = {
-    utils::IteratorType::parallel, utils::IteratorType::reduction
-  };
-
-  if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
+  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
+  if (indexingMaps.size() != 3) return "";
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  if (getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0)))
     return "linalg.conv_1d";
   return "";
 }
@@ -271,74 +255,187 @@ static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
 static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
   ArrayAttr indexingMaps = genericOp.getIndexingMaps();
   if (indexingMaps.size() != 3) return "";
-  SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
-  // Conv 1D
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
   // depthwise_conv_1d_ncw_cw
+  // #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))))
+    return "linalg.depthwise_conv_1d_ncw_cw";
   // depthwise_conv_1d_nwc_wc
-  // ["parallel", "parallel", "parallel", "reduction"]
-  SmallVector<utils::IteratorType> expectedIteratorTypes = {
-    utils::IteratorType::parallel, utils::IteratorType::parallel,
-    utils::IteratorType::parallel, utils::IteratorType::reduction
-  };
-  // inputMapIndex = 0, filterMapIndex = 1, outputMapIndex = 2;
-  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) {
-    if (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))
-      return "linalg.depthwise_conv_1d_ncw_cw";
-    else if (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2))
-      return "linalg.depthwise_conv_1d_nwc_wc";
-  }
-
-  //
-  expectedIteratorTypes[2] = utils::IteratorType::reduction;
-  if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) {
+  // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+  // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))))
+    return "linalg.depthwise_conv_1d_nwc_wc";
+  // conv_2d
+  // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))))
     return "linalg.conv_2d";
-  }
   return "";
 }
 
 static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
   ArrayAttr indexingMaps = genericOp.getIndexingMaps();
   if (indexingMaps.size() != 3) return "";
-  SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
-  // "parallel", "parallel", "parallel", "reduction", "reduction"]
-  SmallVector<utils::IteratorType> expectedIteratorTypes = {
-    utils::IteratorType::parallel, utils::IteratorType::parallel,
-    utils::IteratorType::parallel, utils::IteratorType::parallel,
-    utils::IteratorType::reduction
-  };
-  if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
-    return "linalg.depthwise_conv_1d_nwc_wcm";
-
-  expectedIteratorTypes[3] = utils::IteratorType::reduction;
-  // inputMapIndex = 0, filterMapIndex = 1, outputMapIndex = 2;
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) {
-    if (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2))
-      return "linalg.conv_1d_nwc_wcf";
-    else if (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))
-      return "linalg.conv_1d_ncw_fcw";
-  }
+  // depthwise_conv_1d_nwc_wcm
+  // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) &&
+      (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 3)))
+    return "linalg.depthwise_conv_1d_nwc_wcm";
+  // conv_1d_nwc_wcf
+  // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
+      (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)))
+    return "linalg.conv_1d_nwc_wcf";
+  // conv_1d_ncw_fcw
+  // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+    return "linalg.conv_1d_ncw_fcw";
   return "";
 }
 
 static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
-  SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
-  SmallVector<utils::IteratorType> expectedIteratorTypes = {
-    utils::IteratorType::parallel, utils::IteratorType::reduction
-  };
-  if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
-    return "linalg.conv_1d";
+  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
+  if (indexingMaps.size() < 3) return "";
+  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+  // conv_2d_nhwc_fhwc
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+  if (indexingMaps.size() == 3 &&
+      (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
+      (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 3)))
+    return "linalg.conv_2d_nhwc_fhwc";
+  // conv_2d_nhwc_hwcf
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2)) &&
+      (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)))
+    return "linalg.conv_2d_nhwc_hwcf";
+  // conv_2d_nchw_fchw
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+  if (indexingMaps.size() == 3 &&
+      (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+      (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+    return "linalg.conv_2d_nchw_fchw";
+  // conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc + check total 4 indexing maps)
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+  if (indexingMaps.size() == 5 &&
+      (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
+      (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
+      (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 3)))
+    return "linalg.conv_2d_nhwc_fhwc_q";
+  // conv_2d_nchw_fchw_q (same as conv_2d_nchw_fchw + check total 4 indexing maps)
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+  llvm::outs()<<"Indexing map size = "<<indexingMaps.size()<<"\n";
+  llvm::outs()<<"(indexingMaps[2] == indexingMaps[3]) == "<<(indexingMaps[2] == indexingMaps[3])<<"\n";
+  llvm::outs()<<"cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() = "<<cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults()<<"\n";
+  if (indexingMaps.size() == 5 &&
+      (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
+      (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+      (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+    return "linalg.conv_2d_nchw_fchw_q";
   return "";
 }
 
 static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
-  SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
-  SmallVector<utils::IteratorType> expectedIteratorTypes = {
-    utils::IteratorType::parallel, utils::IteratorType::reduction
-  };
-  if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
-    return "linalg.conv_1d";
+  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
+  if (indexingMaps.size() < 3) return "";
+  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+  // conv_2d_ngchw_fgchw
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && 
+      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
+      (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+    return "linalg.conv_2d_ngchw_fgchw";
+  // conv_2d_ngchw_gfchw
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+  if (indexingMaps.size() == 3 &&
+      (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && 
+      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
+      (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2)))
+    return "linalg.conv_2d_ngchw_gfchw";
+  // conv_2d_ngchw_gfchw_q
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+  if (indexingMaps.size() == 5 &&
+      (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
+      (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && 
+      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
+      (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2)))
+    return "linalg.conv_2d_ngchw_gfchw_q";
+  // conv_2d_nhwgc_gfhwc
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 4)) && 
+      (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 4)))
+    return "linalg.conv_2d_nhwgc_gfhwc";
   return "";
 }
 
@@ -382,8 +479,28 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
     namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNcwCwOp>(genericOp, resultTypes, inputs, outputs);
   } else if (convKind == "linalg.depthwise_conv_1d_nwc_wc") {
     namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNwcWcOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.depthwise_conv_1d_nwc_wcm") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNwcWcmOp>(genericOp, resultTypes, inputs, outputs);
   } else if (convKind == "linalg.conv_2d") {
     namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.conv_2d_nhwc_fhwc") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcFhwcOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.conv_2d_nhwc_hwcf") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcHwcfOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.conv_2d_nchw_fchw") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNchwFchwOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.conv_2d_nhwc_fhwc_q") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcFhwcQOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.conv_2d_nchw_fchw_q") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNchwFchwQOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.conv_2d_ngchw_fgchw") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNgchwFgchwOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.conv_2d_ngchw_gfchw") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNgchwGfchwOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.conv_2d_ngchw_gfchw_q") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNgchwGfchwQOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.conv_2d_nhwgc_gfhwc") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwgcGfhwcOp>(genericOp, resultTypes, inputs, outputs);
   }
   return namedOp;
 

>From dac92f142afdfd6a7a616574e1b0983513a37ba1 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Fri, 3 Oct 2025 06:56:34 -0500
Subject: [PATCH 03/18] Conv complete -> start Pool op now

---
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 146 +++++++++++++++++-
 1 file changed, 142 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 84b080fd53535..6603967b991ab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -316,6 +316,39 @@ static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
   return "";
 }
 
+/// Infers which six-loop named convolution a linalg.generic with exactly
+/// three indexing maps (input, filter, output) represents, by comparing the
+/// results of those maps. Returns the named op's name, or "" if none match.
+static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
+  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
+  // Every op recognized here takes exactly (input, filter, output). A 5-map
+  // quantized generic (e.g. depthwise_conv_2d_nhwc_hwc_q, whose last map is
+  // still the output) would otherwise match the non-quantized maps below and
+  // be rebuilt with its zero points passed as extra inputs.
+  if (indexingMaps.size() != 3) return "";
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // depthwise_conv_2d_nchw_chw
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))))
+    return "linalg.depthwise_conv_2d_nchw_chw";
+  // depthwise_conv_2d_nhwc_hwc
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)))
+    return "linalg.depthwise_conv_2d_nhwc_hwc";
+  // conv_3d
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))))
+    return "linalg.conv_3d";
+  return "";
+}
+
 static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
   ArrayAttr indexingMaps = genericOp.getIndexingMaps();
   if (indexingMaps.size() < 3) return "";
@@ -370,9 +403,6 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  llvm::outs()<<"Indexing map size = "<<indexingMaps.size()<<"\n";
-  llvm::outs()<<"(indexingMaps[2] == indexingMaps[3]) == "<<(indexingMaps[2] == indexingMaps[3])<<"\n";
-  llvm::outs()<<"cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() = "<<cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults()<<"\n";
   if (indexingMaps.size() == 5 &&
       (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
       (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
@@ -381,6 +411,30 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
       (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
       (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
     return "linalg.conv_2d_nchw_fchw_q";
+  // depthwise_conv_2d_nhwc_hwcm
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
+  if (indexingMaps.size() == 3 &&
+      (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
+      (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4)))
+    return "linalg.depthwise_conv_2d_nhwc_hwcm";
+  // depthwise_conv_2d_nhwc_hwcm_q
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
+  if (indexingMaps.size() == 5 &&
+      (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
+      (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
+      (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4)))
+    return "linalg.depthwise_conv_2d_nhwc_hwcm_q";
   return "";
 }
 
@@ -397,7 +451,7 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
       (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && 
       (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
       (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
-      (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+      (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 2)))
     return "linalg.conv_2d_ngchw_fgchw";
   // conv_2d_ngchw_gfchw
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
@@ -436,6 +490,66 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
       (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 4)) && 
       (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 4)))
     return "linalg.conv_2d_nhwgc_gfhwc";
+  // depthwise_conv_3d_ncdhw_cdhw
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 4))))
+    return "linalg.depthwise_conv_3d_ncdhw_cdhw";
+  // depthwise_conv_3d_ndhwc_dhwc
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)))
+    return "linalg.depthwise_conv_3d_ndhwc_dhwc";
+  return "";
+}
+
+/// Infers which nine-loop named 3-D convolution a linalg.generic with exactly
+/// three indexing maps (input, filter, output) represents. Returns the named
+/// op's name, or "" if none match.
+static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) {
+  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
+  // Only (input, filter, output) forms are recognized here. A 5-map quantized
+  // generic such as conv_3d_ndhwc_dhwcf_q would otherwise match the
+  // conv_3d_ndhwc_dhwcf maps and be rebuilt with the wrong operand count.
+  if (indexingMaps.size() != 3) return "";
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // conv_3d_ncdhw_fcdhw
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+    (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
+    (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+    (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+    (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
+    (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+    return "linalg.conv_3d_ncdhw_fcdhw";
+  // conv_3d_ndhwc_dhwcf
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+    (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+    (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+    (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+    (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
+    (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)))
+    return "linalg.conv_3d_ndhwc_dhwcf";
+  // depthwise_conv_3d_ndhwc_dhwcm
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+    (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+    (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+    (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+    (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)) &&
+    (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 5)))
+    return "linalg.depthwise_conv_3d_ndhwc_dhwcm";
   return "";
 }
 
@@ -449,10 +563,14 @@ static std::string inferConvolutionKind(GenericOp genericOp) {
       return inferBasedOnRank4ConvIteratorTypes(genericOp);
     case 5:
       return inferBasedOnRank5ConvIteratorTypes(genericOp);
+    case 6:
+      return inferBasedOnRank6ConvIteratorTypes(genericOp);
     case 7:
       return inferBasedOnRank7ConvIteratorTypes(genericOp);
     case 8:
       return inferBasedOnRank8ConvIteratorTypes(genericOp);
+    case 9:
+      return inferBasedOnRank9ConvIteratorTypes(genericOp);
   }
   return "";
 }
@@ -501,6 +619,26 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
     namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNgchwGfchwQOp>(genericOp, resultTypes, inputs, outputs);
   } else if (convKind == "linalg.conv_2d_nhwgc_gfhwc") {
     namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwgcGfhwcOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.depthwise_conv_2d_nchw_chw") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNchwChwOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwc") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwcm") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcmOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwcm_q") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcmQOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.conv_3d") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.conv_3d_ncdhw_fcdhw") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNcdhwFcdhwOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.conv_3d_ndhwc_dhwcf") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNdhwcDhwcfOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwcm") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.depthwise_conv_3d_ncdhw_cdhw") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNcdhwCdhwOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwc") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(genericOp, resultTypes, inputs, outputs);
   }
   return namedOp;
 

>From 789fb856517377966b58ba9b1c4fdea1d1b12324 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 7 Oct 2025 09:04:37 -0500
Subject: [PATCH 04/18] Add pooling ops to the mix - has few issues but we can
 shift to considering dilations/strides now

---
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 167 ++++++++++++++++++
 1 file changed, 167 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 6603967b991ab..2efa410e4b855 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,6 +237,39 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
 }
 
+/// Utility to match block body for linalg.pool* ops: the value yielded by
+/// `body` must be produced by one of `OpTypes`, and that op's two operands
+/// must be two distinct arguments of `body` itself.
+template <typename... OpTypes>
+static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
+  Operation *defOp = yieldVal.getDefiningOp();
+  // isa_and_present also rejects a yielded block argument (null defOp).
+  if (!(isa_and_present<OpTypes>(defOp) || ...)) return false;
+
+  BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
+  BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand(1));
+  if (!lhsArg || !rhsArg) return false;
+  // linalg.generic bodies may use values from enclosing regions, so an
+  // operand being *some* block argument is not enough; the same argument
+  // combined with itself is not a pooling reduction either.
+  return lhsArg != rhsArg && lhsArg.getOwner() == body &&
+         rhsArg.getOwner() == body;
+}
+
+/// Max pooling over floats or signed integers.
+static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal, body);
+}
+
+/// Max pooling over floats or unsigned integers.
+static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal, body);
+}
+
+/// Min pooling over floats or signed integers.
+static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal, body);
+}
+
+/// Min pooling over floats or unsigned integers.
+static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal, body);
+}
+
+/// Sum pooling over integers or floats.
+static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
+}
+
 static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
                                         uint32_t mapIndex, uint32_t dimIndex) {
   auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
@@ -279,6 +312,39 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) &&
       (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))))
     return "linalg.conv_2d";
+  
+  Block *body = genericOp.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  // pooling_ncw_max
+  // pooling_ncw_sum
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
+  // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
+  // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 2)))) {
+    if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
+      return "linalg.pooling_ncw_max";
+    if (bodyMatcherForSumPoolOps(yieldVal, body))
+      return "linalg.pooling_ncw_sum";
+  }
+  // pooling_nwc_max
+  // pooling_nwc_min
+  // pooling_nwc_sum
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+  // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
+  // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2))) {
+    if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
+      return "linalg.pooling_nwc_max";
+    if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
+      return "linalg.pooling_nwc_min";
+    if (bodyMatcherForSumPoolOps(yieldVal, body))
+      return "linalg.pooling_nwc_sum";
+  }
   return "";
 }
 
@@ -346,6 +412,55 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
       (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
       (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))))
     return "linalg.conv_3d";
+
+  Block *body = genericOp.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  // pooling_nchw_max
+  // pooling_nchw_sum
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 3)))) {
+    if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
+      return "linalg.pooling_nchw_max";
+    if (bodyMatcherForSumPoolOps(yieldVal, body))
+      return "linalg.pooling_nchw_sum";
+  }
+  // pooling_nhwc_max
+  // pooling_nhwc_min
+  // pooling_nhwc_sum
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+  // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) {
+    if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
+      return "linalg.pooling_nhwc_max";
+    if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
+      return "linalg.pooling_nhwc_min";
+    if (bodyMatcherForSumPoolOps(yieldVal, body))
+      return "linalg.pooling_nhwc_sum";
+  }
+  // pooling_nhwc_max_unsigned
+  // pooling_nhwc_min_unsigned
+  // Same maps as the nhwc case above; float and signed bodies already
+  // returned there, so in practice this only fires for maxui / minui.
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) {
+    if (bodyMatcherForMaxUnsignedPoolOps(yieldVal, body))
+      return "linalg.pooling_nhwc_max_unsigned";
+    if (bodyMatcherForMinUnsignedPoolOps(yieldVal, body))
+      return "linalg.pooling_nhwc_min_unsigned";
+  }
   return "";
 }
 
@@ -510,6 +625,28 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
       (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
       (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)))
     return "linalg.depthwise_conv_3d_ndhwc_dhwc";
+
+  Block *body = genericOp.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  // pooling_ndhwc_max
+  // pooling_ndhwc_min
+  // pooling_ndhwc_sum
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
+  // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+      (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4))) {
+    if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
+      return "linalg.pooling_ndhwc_max";
+    if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
+      return "linalg.pooling_ndhwc_min";
+    if (bodyMatcherForSumPoolOps(yieldVal, body))
+      return "linalg.pooling_ndhwc_sum";
+  }
   return "";
 }
 
@@ -639,6 +776,36 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
     namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNcdhwCdhwOp>(genericOp, resultTypes, inputs, outputs);
   } else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwc") {
     namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.pooling_nchw_max") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNchwMaxOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.pooling_nchw_sum") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNchwSumOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.pooling_nhwc_max") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.pooling_nhwc_min") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMinOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.pooling_nhwc_sum") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcSumOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.pooling_nhwc_max_unsigned") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.pooling_nhwc_min_unsigned") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMinUnsignedOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.pooling_ncw_max") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNcwMaxOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.pooling_ncw_sum") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNcwSumOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.pooling_nwc_max") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNwcMaxOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.pooling_nwc_min") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNwcMinOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.pooling_nwc_sum") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNwcSumOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.pooling_ndhwc_max") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNdhwcMaxOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.pooling_ndhwc_min") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNdhwcMinOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.pooling_ndhwc_sum") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNdhwcSumOp>(genericOp, resultTypes, inputs, outputs);
   }
   return namedOp;
 

>From 8e65bc6fc1a75de626a62ad72784087b57f2cd65 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 8 Oct 2025 03:42:40 -0500
Subject: [PATCH 05/18] Concisely v1.0

---
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 199 ++++++++++++------
 1 file changed, 130 insertions(+), 69 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 2efa410e4b855..01a5c3bebd146 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -273,14 +273,75 @@ static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
+/// Returns result `dimIndex` of indexing map `mapIndex` in `indexingMaps`, or
+/// a null AffineExpr when that map has fewer results. Note that two
+/// out-of-range lookups both yield null and therefore compare equal.
 static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
                                         uint32_t mapIndex, uint32_t dimIndex) {
   auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
-  return affineMap.getResult(dimIndex);
+  // Bounds-check so callers can probe maps of differing result counts
+  // without indexing past the end of the result list.
+  if (dimIndex < affineMap.getNumResults())
+    return affineMap.getResult(dimIndex);
+  return nullptr;
+}
+
+/// Checks whether `expr` is either:
+/// - a dimension expr alone (implying a coefficient of 1), or
+/// - a dimension expr multiplied by a constant (constant on either side).
+/// On success sets `dim` to the dimension expr and `constantValue` to the
+/// coefficient; on failure both out-parameters are left untouched.
+static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim,
+                                        int64_t &constantValue) {
+  if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
+    dim = dExpr;
+    constantValue = 1;
+    return true;
+  }
+
+  auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+  if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
+    return false;
+
+  AffineExpr lhs = mulExpr.getLHS();
+  AffineExpr rhs = mulExpr.getRHS();
+
+  // dim * constant
+  if (auto dExpr = dyn_cast<AffineDimExpr>(lhs)) {
+    if (auto cst = dyn_cast<AffineConstantExpr>(rhs)) {
+      dim = dExpr;
+      constantValue = cst.getValue();
+      return true;
+    }
+  }
+  // constant * dim
+  if (auto cst = dyn_cast<AffineConstantExpr>(lhs)) {
+    if (auto dExpr = dyn_cast<AffineDimExpr>(rhs)) {
+      dim = dExpr;
+      constantValue = cst.getValue();
+      return true;
+    }
+  }
+  return false;
+}
+
+/// Returns true if result `iDim` of the input map is a sum `a + b` whose terms
+/// are each a dim (optionally scaled by a constant), with one term's dim equal
+/// to result `fDim` of the filter map and the other's equal to result `oDim`
+/// of the output map, in either order.
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim) {
+  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+  AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
+  // getAffineMapDim yields a null expr for an out-of-range result, and plain
+  // dyn_cast does not accept null values; the _if_present form does.
+  auto addExpr = dyn_cast_if_present<AffineBinaryOpExpr>(inpExpr);
+  if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
+    return false;
+
+  AffineExpr dim0, dim1;
+  // TODO(Abhishek-Varma): Use this information in specialize.cpp.
+  int64_t c0, c1;
+
+  if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
+      isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) {
+    // Pattern matched with dims and constants extracted.
+    AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
+    AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
+    return ((dim0 == fExpr && dim1 == oExpr) || (dim1 == fExpr && dim0 == oExpr));
+  }
+  return false;
+}
+
+/// Returns true if result `aDim` of map `aIndex` exists and is the same affine
+/// expression as result `bDim` of map `bIndex`. Requiring a non-null first
+/// result stops two out-of-range lookups (both null) from counting as a match.
+static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned aDim, unsigned bIndex, unsigned bDim) {
+  AffineExpr aExpr = getAffineMapDim(indexingMaps, aIndex, aDim);
+  return aExpr && aExpr == getAffineMapDim(indexingMaps, bIndex, bDim);
 }
 
 static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
   ArrayAttr indexingMaps = genericOp.getIndexingMaps();
   if (indexingMaps.size() != 3) return "";
-  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  if (getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0)))
+  // conv_1d: the single input index is the sum of the filter and output
+  // indices. The map positions are implied by matchConvDimAddExprPattern, so
+  // the former iIndex/fIndex/oIndex locals would now be unused.
+  if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0))
     return "linalg.conv_1d";
   return "";
 }
@@ -295,7 +356,7 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
       (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))))
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2))
     return "linalg.depthwise_conv_1d_ncw_cw";
   // depthwise_conv_1d_nwc_wc
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
@@ -303,14 +364,14 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
       (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))))
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1))
     return "linalg.depthwise_conv_1d_nwc_wc";
   // conv_2d
   // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))))
+  if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1))
     return "linalg.conv_2d";
   
   Block *body = genericOp.getBlock();
@@ -323,7 +384,7 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
   // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
       (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 2)))) {
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2)) {
     if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
       return "linalg.pooling_ncw_max";
     if (bodyMatcherForSumPoolOps(yieldVal, body))
@@ -336,7 +397,7 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
   // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
   // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1)) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2))) {
     if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
       return "linalg.pooling_nwc_max";
@@ -357,7 +418,7 @@ static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
   // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) &&
       (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 3)))
     return "linalg.depthwise_conv_1d_nwc_wcm";
@@ -366,7 +427,7 @@ static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
   // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
       (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)))
     return "linalg.conv_1d_nwc_wcf";
@@ -376,7 +437,7 @@ static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
   // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
       (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
       (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
     return "linalg.conv_1d_ncw_fcw";
   return "";
@@ -392,25 +453,25 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
       (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))))
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3))
     return "linalg.depthwise_conv_2d_nchw_chw";
   // depthwise_conv_2d_nhwc_hwc
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
       (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)))
     return "linalg.depthwise_conv_2d_nhwc_hwc";
   // conv_3d
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))))
+  if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2))
     return "linalg.conv_3d";
 
   Block *body = genericOp.getBlock();
@@ -423,8 +484,8 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
       (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 3)))) {
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3)) {
     if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
       return "linalg.pooling_nchw_max";
     if (bodyMatcherForSumPoolOps(yieldVal, body))
@@ -437,8 +498,8 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
       (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) {
     if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
       return "linalg.pooling_nhwc_max";
@@ -453,13 +514,13 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
       (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) {
     if (bodyMatcherForMaxUnsignedPoolOps(yieldVal, body))
       return "linalg.pooling_nhwc_max_unsigned";
     if (bodyMatcherForMinUnsignedPoolOps(yieldVal, body))
-      return "linalg.pooling_nhwc_max_unsigned";
+      return "linalg.pooling_nhwc_min_unsigned";
   }
   return "";
 }
@@ -474,8 +535,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
   if (indexingMaps.size() == 3 &&
       (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
       (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
       (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 3)))
     return "linalg.conv_2d_nhwc_fhwc";
@@ -484,8 +545,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
       (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2)) &&
       (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)))
     return "linalg.conv_2d_nhwc_hwcf";
@@ -496,8 +557,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
   if (indexingMaps.size() == 3 &&
       (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
       (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
       (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
     return "linalg.conv_2d_nchw_fchw";
   // conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc + check total 4 indexing maps)
@@ -508,8 +569,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
   if (indexingMaps.size() == 5 &&
       (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
       (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
       (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
       (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 3)))
     return "linalg.conv_2d_nhwc_fhwc_q";
@@ -522,8 +583,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
       (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
       (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
       (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
       (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
     return "linalg.conv_2d_nchw_fchw_q";
   // depthwise_conv_2d_nhwc_hwcm
@@ -532,8 +593,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
   if (indexingMaps.size() == 3 &&
       (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
       (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
       (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4)))
     return "linalg.depthwise_conv_2d_nhwc_hwcm";
@@ -545,8 +606,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
   if (indexingMaps.size() == 5 &&
       (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
       (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
       (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
       (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4)))
     return "linalg.depthwise_conv_2d_nhwc_hwcm_q";
@@ -564,8 +625,8 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
       (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
       (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && 
-      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
       (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 2)))
     return "linalg.conv_2d_ngchw_fgchw";
   // conv_2d_ngchw_gfchw
@@ -576,8 +637,8 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
       (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
       (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
       (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && 
-      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
       (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2)))
     return "linalg.conv_2d_ngchw_gfchw";
   // conv_2d_ngchw_gfchw_q
@@ -590,8 +651,8 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
       (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
       (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
       (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && 
-      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
       (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2)))
     return "linalg.conv_2d_ngchw_gfchw_q";
   // conv_2d_nhwgc_gfhwc
@@ -599,8 +660,8 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) &&
       (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
       (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 4)) && 
       (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 4)))
@@ -611,18 +672,18 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
       (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 4))))
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4))
     return "linalg.depthwise_conv_3d_ncdhw_cdhw";
   // depthwise_conv_3d_ndhwc_dhwc
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
       (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)))
     return "linalg.depthwise_conv_3d_ndhwc_dhwc";
 
@@ -636,9 +697,9 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
   // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
-      (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
       (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4))) {
     if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
       return "linalg.pooling_ndhwc_max";
@@ -660,9 +721,9 @@ static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) {
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
     (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
-    (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
-    (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
-    (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
+    matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+    matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+    matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
     (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
     return "linalg.conv_3d_ncdhw_fcdhw";
   // conv_3d_ndhwc_dhwcf
@@ -670,22 +731,22 @@ static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) {
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-    (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
-    (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
-    (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
-    (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
-    (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)))
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+      (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
+      (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)))
     return "linalg.conv_3d_ndhwc_dhwcf";
   // depthwise_conv_3d_ndhwc_dhwcm
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)>
   if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-    (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
-    (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
-    (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
-    (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)) &&
-    (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 5)))
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+      (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)) &&
+      (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 5)))
     return "linalg.depthwise_conv_3d_ndhwc_dhwcm";
   return "";
 }

>From aae7e048da5f9fa34e7eb1a0de7559e823e23866 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 8 Oct 2025 06:23:27 -0500
Subject: [PATCH 06/18] Concise v2.0

---
 .../include/mlir/Dialect/Linalg/Utils/Utils.h |   4 +
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 178 ++++++++++--------
 2 files changed, 100 insertions(+), 82 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 48978eb7663d5..46f9f10789e6c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -114,6 +114,10 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
 // Fusion / Tiling utilities
 //===----------------------------------------------------------------------===//
 
+//===----------------------------------------------------------------------===//
+// Fusion / Tiling utilities
+//===----------------------------------------------------------------------===//
+
 /// The type of loops to be generated during tiling.
 enum class LinalgTilingLoopType {
   Loops = 0,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 01a5c3bebd146..a741cd126dd3b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -354,16 +354,18 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2))
     return "linalg.depthwise_conv_1d_ncw_cw";
   // depthwise_conv_1d_nwc_wc
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1))
     return "linalg.depthwise_conv_1d_nwc_wc";
   // conv_2d
@@ -382,8 +384,8 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
   // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
   // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2)) {
     if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
       return "linalg.pooling_ncw_max";
@@ -396,9 +398,9 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
   // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
   // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2))) {
+      matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2)) {
     if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
       return "linalg.pooling_nwc_max";
     if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
@@ -417,28 +419,29 @@ static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
   // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) &&
-      (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 3)))
+      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3))
     return "linalg.depthwise_conv_1d_nwc_wcm";
   // conv_1d_nwc_wcf
   // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
-      (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)))
+      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2))
     return "linalg.conv_1d_nwc_wcf";
   // conv_1d_ncw_fcw
   // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
-      (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
     return "linalg.conv_1d_ncw_fcw";
   return "";
 }
@@ -451,8 +454,9 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3))
     return "linalg.depthwise_conv_2d_nchw_chw";
@@ -460,10 +464,11 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)))
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3))
     return "linalg.depthwise_conv_2d_nhwc_hwc";
   // conv_3d
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
@@ -482,8 +487,8 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3)) {
     if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
@@ -497,10 +502,10 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) {
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)) {
     if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
       return "linalg.pooling_nhwc_max";
     if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
@@ -513,10 +518,10 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) {
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)) {
     if (bodyMatcherForMaxUnsignedPoolOps(yieldVal, body))
       return "linalg.pooling_nhwc_max_unsigned";
     if (bodyMatcherForMinUnsignedPoolOps(yieldVal, body))
@@ -534,32 +539,32 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
   if (indexingMaps.size() == 3 &&
-      (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
-      (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
-      (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 3)))
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3))
     return "linalg.conv_2d_nhwc_fhwc";
   // conv_2d_nhwc_hwcf
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2)) &&
-      (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)))
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3))
     return "linalg.conv_2d_nhwc_hwcf";
   // conv_2d_nchw_fchw
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
   if (indexingMaps.size() == 3 &&
-      (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
-      (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
     return "linalg.conv_2d_nchw_fchw";
   // conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc + check total 4 indexing maps)
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
@@ -568,11 +573,11 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
   if (indexingMaps.size() == 5 &&
       (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
-      (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
-      (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
-      (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 3)))
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3))
     return "linalg.conv_2d_nhwc_fhwc_q";
   // conv_2d_nchw_fchw_q (same as conv_2d_nchw_fchw + check total 4 indexing maps)
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
@@ -581,22 +586,23 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
   if (indexingMaps.size() == 5 &&
       (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
-      (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
-      (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
     return "linalg.conv_2d_nchw_fchw_q";
   // depthwise_conv_2d_nhwc_hwcm
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
   if (indexingMaps.size() == 3 &&
-      (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
-      (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4)))
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4))
     return "linalg.depthwise_conv_2d_nhwc_hwcm";
   // depthwise_conv_2d_nhwc_hwcm_q
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
@@ -605,11 +611,12 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
   if (indexingMaps.size() == 5 &&
       (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
-      (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
-      (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4)))
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4))
     return "linalg.depthwise_conv_2d_nhwc_hwcm_q";
   return "";
 }
@@ -622,24 +629,26 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && 
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
-      (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 2)))
+      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2))
     return "linalg.conv_2d_ngchw_fgchw";
   // conv_2d_ngchw_gfchw
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
   if (indexingMaps.size() == 3 &&
-      (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && 
+      matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
-      (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2)))
+      matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2))
     return "linalg.conv_2d_ngchw_gfchw";
   // conv_2d_ngchw_gfchw_q
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
@@ -648,30 +657,33 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
   if (indexingMaps.size() == 5 &&
       (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
-      (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && 
+      matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
-      (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2)))
+      matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2))
     return "linalg.conv_2d_ngchw_gfchw_q";
   // conv_2d_nhwgc_gfhwc
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) &&
-      (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 4)) && 
-      (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 4)))
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4))
     return "linalg.conv_2d_nhwgc_gfhwc";
   // depthwise_conv_3d_ncdhw_cdhw
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-      (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4))
@@ -680,11 +692,12 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
-      (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)))
+      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4))
     return "linalg.depthwise_conv_3d_ndhwc_dhwc";
 
   Block *body = genericOp.getBlock();
@@ -696,11 +709,11 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
   // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
-      (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4))) {
+      matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4)) {
     if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
       return "linalg.pooling_ndhwc_max";
     if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
@@ -719,34 +732,35 @@ static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) {
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
-    (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
-    matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
-    matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
-    matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
-    (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
     return "linalg.conv_3d_ncdhw_fcdhw";
   // conv_3d_ndhwc_dhwcf
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
-      (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
-      (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)))
+      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4))
     return "linalg.conv_3d_ndhwc_dhwcf";
   // depthwise_conv_3d_ndhwc_dhwcm
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)>
-  if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
-      (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)) &&
-      (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 5)))
+      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5))
     return "linalg.depthwise_conv_3d_ndhwc_dhwcm";
   return "";
 }

>From bafdb41a8e539d77adcf6253b1631d22c69b2d45 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 8 Oct 2025 07:08:18 -0500
Subject: [PATCH 07/18] Start pulling out into separate APIs

---
 .../include/mlir/Dialect/Linalg/Utils/Utils.h |   9 +-
 .../Dialect/Linalg/Transforms/Specialize.cpp  |  34 +--
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 219 ++++++++++++++++++
 3 files changed, 233 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 46f9f10789e6c..222b66ca51708 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -111,9 +111,16 @@ std::optional<SmallVector<ReassociationIndices>>
 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
 
 //===----------------------------------------------------------------------===//
-// Fusion / Tiling utilities
+// Convolution matcher utilities
 //===----------------------------------------------------------------------===//
 
+/// Convolution matchers. Each returns true if `op` is the named linalg op it
+/// is named after (e.g. isaConv1DNwcWcfOp <-> linalg.conv_1d_nwc_wcf), or a
+/// convolution-like linalg op (such as a linalg.generic) whose indexing maps
+/// have that named op's layout.
+bool isaConv1DOp(LinalgOp op);
+bool isaConv1DNwcWcfOp(LinalgOp op);
+bool isaConv1DNcwFcwOp(LinalgOp op);
+bool isaDepthwiseConv1DNcwCwOp(LinalgOp op);
+bool isaDepthwiseConv1DNwcWcOp(LinalgOp op);
+bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op);
+
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index a741cd126dd3b..8b51b8f13ce0d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -338,35 +338,24 @@ bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned a
 }
 
 static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
-  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
-  if (indexingMaps.size() != 3) return "";
-  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0))
-    return "linalg.conv_1d";
+  // The only candidate checked for this iterator-type rank is linalg.conv_1d;
+  // an empty string means no named conv op matched.
+  if (isaConv1DOp(genericOp)) return "linalg.conv_1d";
   return "";
 }
 
 static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
   ArrayAttr indexingMaps = genericOp.getIndexingMaps();
   if (indexingMaps.size() != 3) return "";
-  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
   // depthwise_conv_1d_ncw_cw
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2))
+  if (isaDepthwiseConv1DNcwCwOp(genericOp))
     return "linalg.depthwise_conv_1d_ncw_cw";
   // depthwise_conv_1d_nwc_wc
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1))
+  if (isaDepthwiseConv1DNwcWcOp(genericOp))
     return "linalg.depthwise_conv_1d_nwc_wc";
   // conv_2d
   // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
@@ -414,34 +403,23 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
 static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
   ArrayAttr indexingMaps = genericOp.getIndexingMaps();
   if (indexingMaps.size() != 3) return "";
-  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // Every candidate below has exactly three indexing maps (input, filter,
+  // output) over five loop dims. The first matching op name is returned; an
+  // empty string means no named conv op matched.
   // depthwise_conv_1d_nwc_wcm
   // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3))
+  if (isaDepthwiseConv1DNwcWcmOp(genericOp))
     return "linalg.depthwise_conv_1d_nwc_wcm";
   // conv_1d_nwc_wcf
   // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2))
+  if (isaConv1DNwcWcfOp(genericOp))
     return "linalg.conv_1d_nwc_wcf";
   // conv_1d_ncw_fcw
   // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
+  if (isaConv1DNcwFcwOp(genericOp))
     return "linalg.conv_1d_ncw_fcw";
   return "";
 }
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 3593b5348d268..12f88caf08fc7 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -240,6 +240,225 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
   return iteratorType == utils::IteratorType::reduction;
 }
 
+//===----------------------------------------------------------------------===//
+// Convolution matcher utilities
+//===----------------------------------------------------------------------===//
+
+/// Utility to match the block body of linalg.pool* ops.
+///
+/// The value yielded by the region must be produced by one of `OpTypes`
+/// (the pooling combiner, e.g. a max/min/add), and both operands of that op
+/// must be arguments of `body` itself. Operands that are arguments of some
+/// other (e.g. enclosing) block are rejected.
+///
+/// \param yieldVal the value passed to the region terminator.
+/// \param body the region block whose arguments the combiner must consume.
+/// \returns true if the body has the expected pooling-combiner shape.
+template <typename... OpTypes>
+static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
+  // isa_and_present tolerates a null defining op (yieldVal is itself a block
+  // argument), so no separate null check is needed.
+  Operation *defOp = yieldVal.getDefiningOp();
+  if (!isa_and_present<OpTypes...>(defOp))
+    return false;
+
+  auto lhsArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
+  auto rhsArg = dyn_cast<BlockArgument>(defOp->getOperand(1));
+  if (!lhsArg || !rhsArg)
+    return false;
+  return lhsArg.getOwner() == body && rhsArg.getOwner() == body;
+}
+
+/// Pool body matcher for max pooling: float max or signed-integer max.
+static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MaxSIOp, arith::MaximumFOp>(yieldVal,
+                                                                   body);
+}
+
+/// Pool body matcher for max pooling: float max or unsigned-integer max.
+static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MaxUIOp, arith::MaximumFOp>(yieldVal,
+                                                                   body);
+}
+
+/// Pool body matcher for min pooling: float min or signed-integer min.
+static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MinSIOp, arith::MinimumFOp>(yieldVal,
+                                                                   body);
+}
+
+/// Pool body matcher for min pooling: float min or unsigned-integer min.
+static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MinUIOp, arith::MinimumFOp>(yieldVal,
+                                                                   body);
+}
+
+/// Pool body matcher for sum pooling: float add or integer add.
+static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::AddFOp, arith::AddIOp>(yieldVal, body);
+}
+
+/// Returns result `dimIndex` of the `mapIndex`-th affine map in
+/// `indexingMaps`. Returns a null AffineExpr when either index is out of
+/// range, so callers can probe maps of unexpected count or rank without
+/// tripping an assertion.
+static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
+                                        uint32_t mapIndex, uint32_t dimIndex) {
+  if (mapIndex >= indexingMaps.size())
+    return nullptr;
+  auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
+  if (dimIndex >= affineMap.getNumResults())
+    return nullptr;
+  return affineMap.getResult(dimIndex);
+}
+
+/// Decomposes `expr` as `dim * constant`.
+///
+/// Two forms are accepted:
+/// - a bare dimension, whose coefficient is taken to be 1;
+/// - a `Mul` node with one AffineDimExpr operand and one AffineConstantExpr
+///   operand, in either order.
+///
+/// On success, writes the dimension to `dim` and the coefficient to
+/// `constantValue`, and returns true. On failure, returns false and leaves
+/// both outputs untouched.
+static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_t &constantValue) {
+  if (isa<AffineDimExpr>(expr)) {
+    dim = expr;
+    constantValue = 1;
+    return true;
+  }
+
+  auto product = dyn_cast<AffineBinaryOpExpr>(expr);
+  if (!product || product.getKind() != AffineExprKind::Mul)
+    return false;
+
+  // Try (dim * cst) first, then (cst * dim).
+  AffineExpr factors[2] = {product.getLHS(), product.getRHS()};
+  for (unsigned pick = 0; pick < 2; ++pick) {
+    auto dimFactor = dyn_cast<AffineDimExpr>(factors[pick]);
+    auto cstFactor = dyn_cast<AffineConstantExpr>(factors[1 - pick]);
+    if (dimFactor && cstFactor) {
+      dim = dimFactor;
+      constantValue = cstFactor.getValue();
+      return true;
+    }
+  }
+  return false;
+}
+
+/// Matches the sliding-window access of a convolution.
+///
+/// Returns true if result `iDim` of the input map (map 0) has the form
+/// `a * c0 + b * c1`, where either summand may be a bare dim. {a, b} must be,
+/// in either order:
+/// - result `fDim` of the filter map (map 1), and
+/// - result `oDim` of the output map (the last map).
+/// Example: `d1 + d3` with output dim `d1` and filter dim `d3`.
+///
+/// The coefficients c0/c1 are extracted but not yet checked, so strided or
+/// dilated forms also match. Missing maps or results yield false instead of
+/// asserting.
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim) {
+  // A convolution needs at least input, filter and output maps. This also
+  // keeps `size() - 1` below from wrapping around on an empty list.
+  if (indexingMaps.size() < 3)
+    return false;
+  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+
+  AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
+  // A missing result yields a null expr, and dyn_cast asserts on null.
+  if (!inpExpr)
+    return false;
+  auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
+  if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
+    return false;
+
+  AffineExpr dim0, dim1;
+  // TODO(Abhishek-Varma): Use this information in specialize.cpp.
+  int64_t c0, c1;
+  if (!isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) ||
+      !isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1))
+    return false;
+
+  AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
+  AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
+  return (dim0 == fExpr && dim1 == oExpr) || (dim1 == fExpr && dim0 == oExpr);
+}
+
+/// Returns true if result `aDim` of map `aIndex` and result `bDim` of map
+/// `bIndex` are the same affine expression (e.g. both are `d0`).
+/// A missing result on either side never matches: two absent (null) results
+/// must not be treated as equal.
+static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned aDim, unsigned bIndex, unsigned bDim) {
+  AffineExpr aExpr = getAffineMapDim(indexingMaps, aIndex, aDim);
+  return aExpr && aExpr == getAffineMapDim(indexingMaps, bIndex, bDim);
+}
+
+/// Returns true iff `indexingMaps` holds exactly `expectedSizes.size()` maps
+/// and the i-th map has exactly `expectedSizes[i]` results.
+static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, ArrayRef<int64_t> expectedSizes) {
+  size_t numMaps = expectedSizes.size();
+  if (indexingMaps.size() != numMaps)
+    return false;
+
+  for (size_t pos = 0; pos < numMaps; ++pos) {
+    auto map = cast<AffineMapAttr>(indexingMaps[pos]).getValue();
+    if (map.getNumResults() != expectedSizes[pos])
+      return false;
+  }
+  return true;
+}
+
+/// Returns true if `op` is linalg.conv_1d, or a convolution-like linalg op
+/// with three single-result indexing maps in the conv_1d layout:
+///   input  (d0, d1) -> (d0 + d1)
+///   filter (d0, d1) -> (d1)
+///   output (d0, d1) -> (d0)
+bool isaConv1DOp(LinalgOp op) {
+  if (isa<linalg::Conv1DOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr maps = op.getIndexingMaps();
+  // The single dim is the sliding window: input = output + filter.
+  return verifyConvIndexingMapSizes(maps, {1, 1, 1}) &&
+         matchConvDimAddExprPattern(maps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0);
+}
+
+/// Returns true if `op` is linalg.conv_1d_nwc_wcf, or a convolution-like
+/// linalg op whose indexing maps have the conv_1d_nwc_wcf layout:
+///   input  (d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)
+///   filter (d0, d1, d2, d3, d4) -> (d3, d4, d2)
+///   output (d0, d1, d2, d3, d4) -> (d0, d1, d2)
+bool isaConv1DNwcWcfOp(LinalgOp op) {
+  if (isa<linalg::Conv1DNwcWcfOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr maps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(maps, {3, 3, 3}))
+    return false;
+
+  const unsigned inputMap = 0, filterMap = 1, outputMap = 2;
+  return matchConvDimExprPattern(maps, inputMap, 0, outputMap, 0) && // N
+         matchConvDimAddExprPattern(maps, /*iDim=*/1, /*fDim=*/0,
+                                    /*oDim=*/1) &&                  // W
+         matchConvDimExprPattern(maps, inputMap, 2, filterMap, 1) && // C
+         matchConvDimExprPattern(maps, filterMap, 2, outputMap, 2);  // F
+}
+
+/// Returns true if `op` is linalg.conv_1d_ncw_fcw, or a convolution-like
+/// linalg op whose indexing maps have the conv_1d_ncw_fcw layout:
+///   input  (d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)
+///   filter (d0, d1, d2, d3, d4) -> (d1, d3, d4)
+///   output (d0, d1, d2, d3, d4) -> (d0, d1, d2)
+bool isaConv1DNcwFcwOp(LinalgOp op) {
+  if (isa<linalg::Conv1DNcwFcwOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr maps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(maps, {3, 3, 3}))
+    return false;
+
+  const unsigned inputMap = 0, filterMap = 1, outputMap = 2;
+  return matchConvDimExprPattern(maps, inputMap, 0, outputMap, 0) && // N
+         matchConvDimExprPattern(maps, inputMap, 1, filterMap, 1) && // C
+         matchConvDimAddExprPattern(maps, /*iDim=*/2, /*fDim=*/2,
+                                    /*oDim=*/2) &&                  // W
+         matchConvDimExprPattern(maps, filterMap, 0, outputMap, 1);  // F
+}
+
+/// Returns true if `op` is linalg.depthwise_conv_1d_ncw_cw, or a
+/// convolution-like linalg op whose indexing maps have that layout:
+///   input  (d0, d1, d2, d3) -> (d0, d2, d1 + d3)
+///   filter (d0, d1, d2, d3) -> (d2, d3)
+///   output (d0, d1, d2, d3) -> (d0, d2, d1)
+bool isaDepthwiseConv1DNcwCwOp(LinalgOp op) {
+  if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr maps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(maps, {3, 2, 3}))
+    return false;
+
+  const unsigned inputMap = 0, filterMap = 1, outputMap = 2;
+  return matchConvDimExprPattern(maps, inputMap, 0, outputMap, 0) && // N
+         matchConvDimExprPattern(maps, inputMap, 1, filterMap, 0) && // C
+         matchConvDimExprPattern(maps, inputMap, 1, outputMap, 1) && // C
+         matchConvDimAddExprPattern(maps, /*iDim=*/2, /*fDim=*/1,
+                                    /*oDim=*/2);                    // W
+}
+
+/// Returns true if `op` is linalg.depthwise_conv_1d_nwc_wc, or a
+/// convolution-like linalg op whose indexing maps have that layout:
+///   input  (d0, d1, d2, d3) -> (d0, d1 + d3, d2)
+///   filter (d0, d1, d2, d3) -> (d3, d2)
+///   output (d0, d1, d2, d3) -> (d0, d1, d2)
+bool isaDepthwiseConv1DNwcWcOp(LinalgOp op) {
+  if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr maps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(maps, {3, 2, 3}))
+    return false;
+
+  const unsigned inputMap = 0, filterMap = 1, outputMap = 2;
+  return matchConvDimExprPattern(maps, inputMap, 0, outputMap, 0) && // N
+         matchConvDimExprPattern(maps, inputMap, 2, filterMap, 1) && // C
+         matchConvDimExprPattern(maps, inputMap, 2, outputMap, 2) && // C
+         matchConvDimAddExprPattern(maps, /*iDim=*/1, /*fDim=*/0,
+                                    /*oDim=*/1);                    // W
+}
+
+/// Returns true if `op` is linalg.depthwise_conv_1d_nwc_wcm, or a
+/// convolution-like linalg op whose indexing maps have that layout:
+///   input  (d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)
+///   filter (d0, d1, d2, d3, d4) -> (d4, d2, d3)
+///   output (d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)
+bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op) {
+  if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr maps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(maps, {3, 3, 4}))
+    return false;
+
+  const unsigned inputMap = 0, filterMap = 1, outputMap = 2;
+  return matchConvDimExprPattern(maps, inputMap, 0, outputMap, 0) && // N
+         matchConvDimAddExprPattern(maps, /*iDim=*/1, /*fDim=*/0,
+                                    /*oDim=*/1) &&                  // W
+         matchConvDimExprPattern(maps, inputMap, 2, filterMap, 1) && // C
+         matchConvDimExprPattern(maps, inputMap, 2, outputMap, 2) && // C
+         matchConvDimExprPattern(maps, filterMap, 2, outputMap, 3);  // M
+}
+
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                             Value source, Value pad, bool nofold,
                             ValueRange typeDynDims) {

>From a08247c4804f504535a061c77fe782e5b75991f4 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 8 Oct 2025 07:47:06 -0500
Subject: [PATCH 08/18] Some more APIs

---
 .../include/mlir/Dialect/Linalg/Utils/Utils.h |  14 +
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 105 +------
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 277 ++++++++++++++++++
 3 files changed, 306 insertions(+), 90 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 222b66ca51708..b4955625b6dec 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -120,6 +120,20 @@ bool isaConv1DNcwFcwOp(LinalgOp op);
 bool isaDepthwiseConv1DNcwCwOp(LinalgOp op);
 bool isaDepthwiseConv1DNwcWcOp(LinalgOp op);
 bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op);
+/// 2-D convolution matchers. Each returns true if `op` is the named linalg op
+/// it is named after (e.g. isaConv2DNhwcFhwcOp <-> linalg.conv_2d_nhwc_fhwc),
+/// or a convolution-like linalg op whose indexing maps have that named op's
+/// layout. The `Q` variants correspond to the quantized `_q` ops.
+bool isaConv2DOp(LinalgOp op);
+bool isaConv2DNhwcFhwcOp(LinalgOp op);
+bool isaConv2DNhwcHwcfOp(LinalgOp op);
+bool isaConv2DNchwFchwOp(LinalgOp op);
+bool isaConv2DNhwcFhwcQOp(LinalgOp op);
+bool isaConv2DNchwFchwQOp(LinalgOp op);
+bool isaConv2DNgchwFgchwOp(LinalgOp op);
+bool isaConv2DNgchwGfchwOp(LinalgOp op);
+bool isaConv2DNgchwGfchwQOp(LinalgOp op);
+bool isaConv2DNhwgcGfhwcOp(LinalgOp op);
+bool isaDepthwiseConv2DNchwChwOp(LinalgOp op);
+bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op);
+bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op);
+bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op);
 
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 8b51b8f13ce0d..968370c05615a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -361,10 +361,10 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
   // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
-  if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1))
+  if (isaConv2DOp(genericOp))
     return "linalg.conv_2d";
   
+  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
   Block *body = genericOp.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
@@ -432,21 +432,13 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3))
+  if (isaDepthwiseConv2DNchwChwOp(genericOp))
     return "linalg.depthwise_conv_2d_nchw_chw";
   // depthwise_conv_2d_nhwc_hwc
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3))
+  if (isaDepthwiseConv2DNhwcHwcOp(genericOp))
     return "linalg.depthwise_conv_2d_nhwc_hwc";
   // conv_3d
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
@@ -511,90 +503,50 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
 static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
   ArrayAttr indexingMaps = genericOp.getIndexingMaps();
   if (indexingMaps.size() < 3) return "";
-  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+  // Plain variants use three indexing maps and the quantized `_q` variants
+  // use more, so only fewer than three is rejected up front. The first
+  // matching op name is returned; an empty string means no match.
   // conv_2d_nhwc_fhwc
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  if (indexingMaps.size() == 3 &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3))
+  if (isaConv2DNhwcFhwcOp(genericOp))
     return "linalg.conv_2d_nhwc_fhwc";
   // conv_2d_nhwc_hwcf
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3))
+  if (isaConv2DNhwcHwcfOp(genericOp))
     return "linalg.conv_2d_nhwc_hwcf";
   // conv_2d_nchw_fchw
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  if (indexingMaps.size() == 3 &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
+  if (isaConv2DNchwFchwOp(genericOp))
     return "linalg.conv_2d_nchw_fchw";
   // conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc + check total 4 indexing maps)
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  if (indexingMaps.size() == 5 &&
-      (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3))
+  if (isaConv2DNhwcFhwcQOp(genericOp))
     return "linalg.conv_2d_nhwc_fhwc_q";
   // conv_2d_nchw_fchw_q (same as conv_2d_nchw_fchw + check total 4 indexing maps)
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  if (indexingMaps.size() == 5 &&
-      (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
+  if (isaConv2DNchwFchwQOp(genericOp))
     return "linalg.conv_2d_nchw_fchw_q";
   // depthwise_conv_2d_nhwc_hwcm
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
-  if (indexingMaps.size() == 3 &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4))
+  if (isaDepthwiseConv2DNhwcHwcmOp(genericOp))
     return "linalg.depthwise_conv_2d_nhwc_hwcm";
   // depthwise_conv_2d_nhwc_hwcm_q
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
-  if (indexingMaps.size() == 5 &&
-      (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4))
+  if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp))
     return "linalg.depthwise_conv_2d_nhwc_hwcm_q";
   return "";
 }
@@ -607,53 +559,26 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2))
+  if (isaConv2DNgchwFgchwOp(genericOp))
     return "linalg.conv_2d_ngchw_fgchw";
   // conv_2d_ngchw_gfchw
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  if (indexingMaps.size() == 3 &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2))
+  if (isaConv2DNgchwGfchwOp(genericOp))
     return "linalg.conv_2d_ngchw_gfchw";
   // conv_2d_ngchw_gfchw_q
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  if (indexingMaps.size() == 5 &&
-      (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2))
+  if (isaConv2DNgchwGfchwQOp(genericOp))
     return "linalg.conv_2d_ngchw_gfchw_q";
   // conv_2d_nhwgc_gfhwc
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4))
+  if (isaConv2DNhwgcGfhwcOp(genericOp))
     return "linalg.conv_2d_nhwgc_gfhwc";
   // depthwise_conv_3d_ncdhw_cdhw
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 12f88caf08fc7..c5bb184c726f8 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -459,6 +459,283 @@ bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op) {
       matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3));
 }
 
+bool isaConv2DOp(LinalgOp op) {
+  if (isa<linalg::Conv2DOp>(op)) return true;
+  // Otherwise, match the op structurally via its indexing maps.
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {2,2,2})) return false;
+  // Expected indexing maps (input, filter, output):
+  // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+  return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1));
+}
+
+bool isaConv2DNhwcFhwcOp(LinalgOp op) {
+  if (isa<linalg::Conv2DNhwcFhwcOp>(op)) return true;
+  // Otherwise, match the op structurally via its indexing maps.
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false;
+  // Positions of the input, filter and output maps in indexingMaps.
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3));
+}
+
+bool isaConv2DNhwcHwcfOp(LinalgOp op) {
+  if (isa<linalg::Conv2DNhwcHwcfOp>(op)) return true;
+  // Otherwise, match the op structurally via its indexing maps.
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false;
+  // Positions of the input, filter and output maps in indexingMaps.
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3));
+}
+
+bool isaConv2DNchwFchwOp(LinalgOp op) {
+  if (isa<linalg::Conv2DNchwFchwOp>(op)) return true;
+  // Otherwise, match the op structurally via its indexing maps.
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false;
+  // Positions of the input, filter and output maps in indexingMaps.
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+}
+
+bool isaConv2DNhwcFhwcQOp(LinalgOp op) {
+  if (isa<linalg::Conv2DNhwcFhwcQOp>(op)) return true;
+  // Otherwise, match the op structurally via its indexing maps.
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false;
+  // Maps 2 and 3 are expected to have no results; the output map is index 4.
+  unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3));
+}
+
+bool isaConv2DNchwFchwQOp(LinalgOp op) {
+  if (isa<linalg::Conv2DNchwFchwQOp>(op)) return true;
+  // Otherwise, match the op structurally via its indexing maps.
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false;
+  // Maps 2 and 3 are expected to have no results; the output map is index 4.
+  unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+}
+
+bool isaConv2DNgchwFgchwOp(LinalgOp op) {
+  if (isa<linalg::Conv2DNgchwFgchwOp>(op)) return true;
+  // Otherwise, match the op structurally via its indexing maps.
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
+  // Positions of the input, filter and output maps in indexingMaps.
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2));
+}
+
+bool isaConv2DNgchwGfchwOp(LinalgOp op) {
+  if (isa<linalg::Conv2DNgchwGfchwOp>(op)) return true;
+  // Otherwise, match the op structurally via its indexing maps.
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
+  // Positions of the input, filter and output maps in indexingMaps.
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+  return (indexingMaps.size() == 3 &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2));
+}
+
+bool isaConv2DNgchwGfchwQOp(LinalgOp op) {
+  if (isa<linalg::Conv2DNgchwGfchwQOp>(op)) return true;
+  // Otherwise, match the op structurally via its indexing maps.
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
+  // Maps 2 and 3 are expected to have no results; the output map is index 4.
+  unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2));
+}
+
+bool isaConv2DNhwgcGfhwcOp(LinalgOp op) {
+  if (isa<linalg::Conv2DNhwgcGfhwcOp>(op)) return true;
+  // Otherwise, match the op structurally via its indexing maps.
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
+  // Positions of the input, filter and output maps in indexingMaps.
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4));
+}
+
+bool isaDepthwiseConv2DNchwChwOp(LinalgOp op) {
+  if (isa<linalg::DepthwiseConv2DNchwChwOp>(op)) return true;
+  // Otherwise, match the op structurally via its indexing maps.
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,4})) return false;
+  // Positions of the input, filter and output maps in indexingMaps.
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3));
+}
+
+bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(op)) return true;
+  // Otherwise, match the op structurally via its indexing maps.
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,4})) return false;
+  // Positions of the input, filter and output maps in indexingMaps.
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3));
+}
+
+bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op)) return true;
+  // Otherwise, match the op structurally via its indexing maps.
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,5})) return false;
+  // Positions of the input, filter and output maps in indexingMaps.
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
+}
+
+bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcmQOp>(op)) return true;
+  // Otherwise, match the op structurally via its indexing maps.
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,5})) return false;
+  // Maps 2 and 3 are expected to have no results; the output map is index 4.
+  unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
+}
+
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                             Value source, Value pad, bool nofold,
                             ValueRange typeDynDims) {

>From 053d912f321b72cde08df10ef5cc9690248247cc Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 8 Oct 2025 07:59:55 -0500
Subject: [PATCH 09/18] Clean a bit

---
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 97 +------------------
 1 file changed, 5 insertions(+), 92 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 968370c05615a..ea94b49946545 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -343,27 +343,14 @@ static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
 }
 
 static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
-  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
-  if (indexingMaps.size() != 3) return "";
-  // depthwise_conv_1d_ncw_cw
-  // #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)>
-  // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
-  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
   if (isaDepthwiseConv1DNcwCwOp(genericOp))
     return "linalg.depthwise_conv_1d_ncw_cw";
-  // depthwise_conv_1d_nwc_wc
-  // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
-  // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
-  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
   if (isaDepthwiseConv1DNwcWcOp(genericOp))
     return "linalg.depthwise_conv_1d_nwc_wc";
-  // conv_2d
-  // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
-  // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
-  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
   if (isaConv2DOp(genericOp))
     return "linalg.conv_2d";
   
+  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
   unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
   Block *body = genericOp.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
@@ -401,45 +388,24 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
 }
 
 static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
-  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
-  if (indexingMaps.size() != 3) return "";
-  // depthwise_conv_1d_nwc_wcm
-  // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
   if (isaDepthwiseConv1DNwcWcmOp(genericOp))
     return "linalg.depthwise_conv_1d_nwc_wcm";
-  // conv_1d_nwc_wcf
-  // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
   if (isaConv1DNwcWcfOp(genericOp))
     return "linalg.conv_1d_nwc_wcf";
-  // conv_1d_ncw_fcw
-  // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
   if (isaConv1DNcwFcwOp(genericOp))
     return "linalg.conv_1d_ncw_fcw";
   return "";
 }
 
 static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
-  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
-  if (indexingMaps.size() < 3) return "";
-  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
-  // depthwise_conv_2d_nchw_chw
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
   if (isaDepthwiseConv2DNchwChwOp(genericOp))
     return "linalg.depthwise_conv_2d_nchw_chw";
-  // depthwise_conv_2d_nhwc_hwc
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
   if (isaDepthwiseConv2DNhwcHwcOp(genericOp))
     return "linalg.depthwise_conv_2d_nhwc_hwc";
+  
+  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
+  if (indexingMaps.size() < 3) return "";
+  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
   // conv_3d
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
@@ -501,83 +467,30 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
 }
 
 static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
-  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
-  if (indexingMaps.size() < 3) return "";
-  // conv_2d_nhwc_fhwc
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
   if (isaConv2DNhwcFhwcOp(genericOp))
     return "linalg.conv_2d_nhwc_fhwc";
-  // conv_2d_nhwc_hwcf
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
   if (isaConv2DNhwcHwcfOp(genericOp))
     return "linalg.conv_2d_nhwc_hwcf";
-  // conv_2d_nchw_fchw
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
   if (isaConv2DNchwFchwOp(genericOp))
     return "linalg.conv_2d_nchw_fchw";
-  // conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc + check total 4 indexing maps)
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
-  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
   if (isaConv2DNhwcFhwcQOp(genericOp))
     return "linalg.conv_2d_nhwc_fhwc_q";
-  // conv_2d_nchw_fchw_q (same as conv_2d_nchw_fchw + check total 4 indexing maps)
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
-  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
   if (isaConv2DNchwFchwQOp(genericOp))
     return "linalg.conv_2d_nchw_fchw_q";
-  // depthwise_conv_2d_nhwc_hwcm
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
   if (isaDepthwiseConv2DNhwcHwcmOp(genericOp))
     return "linalg.depthwise_conv_2d_nhwc_hwcm";
-  // depthwise_conv_2d_nhwc_hwcm_q
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
-  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
   if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp))
     return "linalg.depthwise_conv_2d_nhwc_hwcm_q";
   return "";
 }
 
 static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
-  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
-  if (indexingMaps.size() < 3) return "";
-  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
-  // conv_2d_ngchw_fgchw
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
   if (isaConv2DNgchwFgchwOp(genericOp))
     return "linalg.conv_2d_ngchw_fgchw";
-  // conv_2d_ngchw_gfchw
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
   if (isaConv2DNgchwGfchwOp(genericOp))
     return "linalg.conv_2d_ngchw_gfchw";
-  // conv_2d_ngchw_gfchw_q
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
-  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
   if (isaConv2DNgchwGfchwQOp(genericOp))
     return "linalg.conv_2d_ngchw_gfchw_q";
-  // conv_2d_nhwgc_gfhwc
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
   if (isaConv2DNhwgcGfhwcOp(genericOp))
     return "linalg.conv_2d_nhwgc_gfhwc";
   // depthwise_conv_3d_ncdhw_cdhw

>From 87b91ee5602ff9b031f98a980f6283d517190092 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 9 Oct 2025 02:36:54 -0500
Subject: [PATCH 10/18] Add 3D APIs

---
 .../include/mlir/Dialect/Linalg/Utils/Utils.h |   6 +
 .../Dialect/Linalg/Transforms/Specialize.cpp  |  70 ++---------
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 117 ++++++++++++++++++
 3 files changed, 132 insertions(+), 61 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index b4955625b6dec..ad5e0818b90f5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -134,6 +134,12 @@ bool isaDepthwiseConv2DNchwChwOp(LinalgOp op);
 bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op);
 bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op);
 bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op);
+bool isaConv3DOp(LinalgOp op);
+bool isaConv3DNcdhwFcdhwOp(LinalgOp op);
+bool isaConv3DNdhwcDhwcfOp(LinalgOp op);
+bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op);
+bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op);
+bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op);
 
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index ea94b49946545..6ecc6a024bed8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -406,13 +406,7 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
   ArrayAttr indexingMaps = genericOp.getIndexingMaps();
   if (indexingMaps.size() < 3) return "";
   unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
-  // conv_3d
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
-  if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2))
+  if (isaConv3DOp(genericOp))
     return "linalg.conv_3d";
 
   Block *body = genericOp.getBlock();
@@ -493,29 +487,14 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
     return "linalg.conv_2d_ngchw_gfchw_q";
   if (isaConv2DNhwgcGfhwcOp(genericOp))
     return "linalg.conv_2d_nhwgc_gfhwc";
-  // depthwise_conv_3d_ncdhw_cdhw
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4))
+  if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp))
     return "linalg.depthwise_conv_3d_ncdhw_cdhw";
-  // depthwise_conv_3d_ndhwc_dhwc
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4))
+  if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp))
     return "linalg.depthwise_conv_3d_ndhwc_dhwc";
 
+  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
+  if (indexingMaps.size() < 3) return "";
+  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
   Block *body = genericOp.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
@@ -541,42 +520,11 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
 }
 
 static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) {
-  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
-  if (indexingMaps.size() < 3) return "";
-  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
-  // conv_3d_ncdhw_fcdhw
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
+  if (isaConv3DNcdhwFcdhwOp(genericOp))
     return "linalg.conv_3d_ncdhw_fcdhw";
-  // conv_3d_ndhwc_dhwcf
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4))
+  if (isaConv3DNdhwcDhwcfOp(genericOp))
     return "linalg.conv_3d_ndhwc_dhwcf";
-  // depthwise_conv_3d_ndhwc_dhwcm
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5))
+  if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp))
     return "linalg.depthwise_conv_3d_ndhwc_dhwcm";
   return "";
 }
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index c5bb184c726f8..b3e79e8c1a409 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -736,6 +736,123 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) {
       matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
 }
 
+bool isaConv3DOp(LinalgOp op) {
+  if (isa<linalg::Conv1DOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false;
+  
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
+  return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2));
+}
+
+bool isaConv3DNcdhwFcdhwOp(LinalgOp op) {
+  if (isa<linalg::Conv1DOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
+  
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+}
+
+bool isaConv3DNdhwcDhwcfOp(LinalgOp op) {
+  if (isa<linalg::Conv1DOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
+  
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
+}
+
+bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) {
+  if (isa<linalg::Conv1DOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,6})) return false;
+  
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5));
+}
+
+bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op) {
+  if (isa<linalg::Conv1DOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false;
+  
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4));
+}
+
+bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op) {
+  if (isa<linalg::Conv1DOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false;
+  
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4));
+}
+
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                             Value source, Value pad, bool nofold,
                             ValueRange typeDynDims) {

>From e5acca4e6d5c76cfd00688d487c16f6082ecd3a7 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 9 Oct 2025 02:43:28 -0500
Subject: [PATCH 11/18] Fix the NamedOp versions

---
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 40 ++++++++++++-------------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index b3e79e8c1a409..2d6d51d858853 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -460,7 +460,7 @@ bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op) {
 }
 
 bool isaConv2DOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::Conv2DOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -475,7 +475,7 @@ bool isaConv2DOp(LinalgOp op) {
 }
 
 bool isaConv2DNhwcFhwcOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::Conv2DNhwcFhwcOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -494,7 +494,7 @@ bool isaConv2DNhwcFhwcOp(LinalgOp op) {
 }
 
 bool isaConv2DNhwcHwcfOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::Conv2DNhwcHwcfOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -513,7 +513,7 @@ bool isaConv2DNhwcHwcfOp(LinalgOp op) {
 }
 
 bool isaConv2DNchwFchwOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::Conv2DNchwFchwOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -532,7 +532,7 @@ bool isaConv2DNchwFchwOp(LinalgOp op) {
 }
 
 bool isaConv2DNhwcFhwcQOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::Conv2DNhwcFhwcQOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -552,7 +552,7 @@ bool isaConv2DNhwcFhwcQOp(LinalgOp op) {
 }
 
 bool isaConv2DNchwFchwQOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::Conv2DNchwFchwQOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -572,7 +572,7 @@ bool isaConv2DNchwFchwQOp(LinalgOp op) {
 }
 
 bool isaConv2DNgchwFgchwOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::Conv2DNgchwFgchwOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -593,7 +593,7 @@ bool isaConv2DNgchwFgchwOp(LinalgOp op) {
 }
 
 bool isaConv2DNgchwGfchwOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::Conv2DNgchwGfchwOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -615,7 +615,7 @@ bool isaConv2DNgchwGfchwOp(LinalgOp op) {
 }
 
 bool isaConv2DNgchwGfchwQOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::Conv2DNgchwGfchwQOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -637,7 +637,7 @@ bool isaConv2DNgchwGfchwQOp(LinalgOp op) {
 }
 
 bool isaConv2DNhwgcGfhwcOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::Conv2DNhwgcGfhwcOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -658,7 +658,7 @@ bool isaConv2DNhwgcGfhwcOp(LinalgOp op) {
 }
 
 bool isaDepthwiseConv2DNchwChwOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::DepthwiseConv2DNchwChwOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -677,7 +677,7 @@ bool isaDepthwiseConv2DNchwChwOp(LinalgOp op) {
 }
 
 bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -696,7 +696,7 @@ bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op) {
 }
 
 bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -716,7 +716,7 @@ bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op) {
 }
 
 bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::DepthwiseConv2DNhwcHwcmQOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -737,7 +737,7 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) {
 }
 
 bool isaConv3DOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::Conv3DOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -753,7 +753,7 @@ bool isaConv3DOp(LinalgOp op) {
 }
 
 bool isaConv3DNcdhwFcdhwOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::Conv3DNcdhwFcdhwOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -773,7 +773,7 @@ bool isaConv3DNcdhwFcdhwOp(LinalgOp op) {
 }
 
 bool isaConv3DNdhwcDhwcfOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::Conv3DNdhwcDhwcfOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -793,7 +793,7 @@ bool isaConv3DNdhwcDhwcfOp(LinalgOp op) {
 }
 
 bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -814,7 +814,7 @@ bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) {
 }
 
 bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::DepthwiseConv3DNcdhwCdhwOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 
@@ -834,7 +834,7 @@ bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op) {
 }
 
 bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::DepthwiseConv3DNdhwcDhwcOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
 

>From 535f7e95288274f104b5c6fda01012016613d3b3 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 9 Oct 2025 03:22:39 -0500
Subject: [PATCH 12/18] Add matchers for pooling ops

---
 .../include/mlir/Dialect/Linalg/Utils/Utils.h |  15 +
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 243 ++-----------
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 328 ++++++++++++++++++
 3 files changed, 373 insertions(+), 213 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index ad5e0818b90f5..1a1b70d3eb979 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -140,6 +140,21 @@ bool isaConv3DNdhwcDhwcfOp(LinalgOp op);
 bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op);
 bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op);
 bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op);
+bool isaPoolingNchwMaxOp(LinalgOp op);
+bool isaPoolingNchwSumOp(LinalgOp op);
+bool isaPoolingNhwcMaxOp(LinalgOp op);
+bool isaPoolingNhwcMinOp(LinalgOp op);
+bool isaPoolingNhwcSumOp(LinalgOp op);
+bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op);
+bool isaPoolingNhwcMinUnsignedOp(LinalgOp op);
+bool isaPoolingNcwMaxOp(LinalgOp op);
+bool isaPoolingNcwSumOp(LinalgOp op);
+bool isaPoolingNwcMaxOp(LinalgOp op);
+bool isaPoolingNwcMinOp(LinalgOp op);
+bool isaPoolingNwcSumOp(LinalgOp op);
+bool isaPoolingNdhwcMaxOp(LinalgOp op);
+bool isaPoolingNdhwcMinOp(LinalgOp op);
+bool isaPoolingNdhwcSumOp(LinalgOp op);
 
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 6ecc6a024bed8..aef3a1480d289 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,106 +237,6 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
 }
 
-/// Utility to match block body for linalg.pool* ops.
-template <typename... OpTypes>
-static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
-  Operation *defOp = yieldVal.getDefiningOp();
-  // if (!defOp) return false;
-  if (!(isa_and_present<OpTypes>(defOp) || ...)) return false;
-
-  BlockArgument lhsArg =  dyn_cast<BlockArgument>(defOp->getOperand(0));
-  BlockArgument rhsArg =  dyn_cast<BlockArgument>(defOp->getOperand(1));
-  if (!lhsArg || !rhsArg) return false;
-  return true;
-}
-
-static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
-  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal, body);
-}
-
-static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
-  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal, body);
-}
-
-static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
-  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal, body);
-}
-
-static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
-  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal, body);
-}
-
-static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
-  return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
-}
-
-static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
-                                        uint32_t mapIndex, uint32_t dimIndex) {
-  auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
-  if (dimIndex < affineMap.getNumResults())
-    return affineMap.getResult(dimIndex);
-  return nullptr;
-}
-
-// Check if `expr` is either:
-// - a dimension expr alone (implying *1), or
-// - a multiplication of dimension expr by constant.
-bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_t &constantValue) {
-  if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
-    dim = dExpr;
-    constantValue = 1;
-    return true;
-  }
-
-  auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
-  if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
-    return false;
-
-  AffineExpr lhs = mulExpr.getLHS();
-  AffineExpr rhs = mulExpr.getRHS();
-
-  if (auto dExpr = dyn_cast<AffineDimExpr>(lhs)) {
-    if (auto cst = dyn_cast<AffineConstantExpr>(rhs)) {
-      dim = dExpr;
-      constantValue = cst.getValue();
-      return true;
-    }
-  }
-  if (auto cst = dyn_cast<AffineConstantExpr>(lhs)) {
-    if (auto dExpr = dyn_cast<AffineDimExpr>(rhs)) {
-      dim = dExpr;
-      constantValue = cst.getValue();
-      return true;
-    }
-  }
-  return false;
-}
-
-bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim) {
-  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
-  AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
-  auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
-  if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
-    return false;
-
-  AffineExpr dim0, dim1;
-  // TODO(Abhishek-Varma): Use this information in specialize.cpp.
-  int64_t c0, c1;
-
-  if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
-      isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) {
-    // Pattern matched with dims and constants extracted.
-    AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
-    AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
-    return ((dim0 == fExpr && dim1 == oExpr) || (dim1 == fExpr && dim0 == oExpr));
-  }
-  return false;
-}
-
-bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned aDim, unsigned bIndex, unsigned bDim) {
-  return getAffineMapDim(indexingMaps, aIndex, aDim) == getAffineMapDim(indexingMaps, bIndex, bDim);
-}
-
 static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
   if (isaConv1DOp(genericOp)) return "linalg.conv_1d";
   return "";
@@ -349,41 +249,16 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
     return "linalg.depthwise_conv_1d_nwc_wc";
   if (isaConv2DOp(genericOp))
     return "linalg.conv_2d";
-  
-  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
-  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
-  Block *body = genericOp.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  // pooling_ncw_max
-  // pooling_ncw_sum
-  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
-  // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
-  // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2)) {
-    if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
-      return "linalg.pooling_ncw_max";
-    if (bodyMatcherForSumPoolOps(yieldVal, body))
-      return "linalg.pooling_ncw_sum";
-  }
-  // pooling_nwc_max
-  // pooling_nwc_min
-  // pooling_nwc_sum
-  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
-  // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
-  // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2)) {
-    if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
-      return "linalg.pooling_nwc_max";
-    if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
-      return "linalg.pooling_nwc_min";
-    if (bodyMatcherForSumPoolOps(yieldVal, body))
-      return "linalg.pooling_nwc_sum";
-  }
+  if (isaPoolingNcwMaxOp(genericOp))
+    return "linalg.pooling_ncw_max";
+  if (isaPoolingNcwSumOp(genericOp))
+    return "linalg.pooling_ncw_sum";
+  if (isaPoolingNwcMaxOp(genericOp))
+    return "linalg.pooling_nwc_max";
+  if (isaPoolingNwcMinOp(genericOp))
+    return "linalg.pooling_nwc_min";
+  if (isaPoolingNwcSumOp(genericOp))
+    return "linalg.pooling_nwc_sum";
   return "";
 }
 
@@ -402,61 +277,22 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
     return "linalg.depthwise_conv_2d_nchw_chw";
   if (isaDepthwiseConv2DNhwcHwcOp(genericOp))
     return "linalg.depthwise_conv_2d_nhwc_hwc";
-  
-  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
-  if (indexingMaps.size() < 3) return "";
-  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
   if (isaConv3DOp(genericOp))
     return "linalg.conv_3d";
-
-  Block *body = genericOp.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  // pooling_nchw_max
-  // pooling_nchw_sum
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3)) {
-    if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
-      return "linalg.pooling_nchw_max";
-    if (bodyMatcherForSumPoolOps(yieldVal, body))
-      return "linalg.pooling_nchw_sum";
-  }
-  // pooling_nhwc_max
-  // pooling_nhwc_min
-  // pooling_nhwc_sum
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
-  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
-  // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)) {
-    if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
-      return "linalg.pooling_nhwc_max";
-    if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
-      return "linalg.pooling_nhwc_min";
-    if (bodyMatcherForSumPoolOps(yieldVal, body))
-      return "linalg.pooling_nhwc_sum";
-  }
-  // pooling_nhwc_max_unsigned
-  // pooling_nhwc_min_unsigned
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)) {
-    if (bodyMatcherForMaxUnsignedPoolOps(yieldVal, body))
-      return "linalg.pooling_nhwc_max_unsigned";
-    if (bodyMatcherForMinUnsignedPoolOps(yieldVal, body))
-      return "linalg.pooling_nhwc_min_unsigned";
-  }
+  if (isaPoolingNchwMaxOp(genericOp))
+    return "linalg.pooling_nchw_max";
+  if (isaPoolingNchwSumOp(genericOp))
+    return "linalg.pooling_nchw_sum";
+  if (isaPoolingNhwcMaxOp(genericOp))
+    return "linalg.pooling_nhwc_max";
+  if (isaPoolingNhwcMinOp(genericOp))
+    return "linalg.pooling_nhwc_min";
+  if (isaPoolingNhwcSumOp(genericOp))
+    return "linalg.pooling_nhwc_sum";
+  if (isaPoolingNhwcMaxUnsignedOp(genericOp))
+    return "linalg.pooling_nhwc_max_unsigned";
+  if (isaPoolingNhwcMinUnsignedOp(genericOp))
+    return "linalg.pooling_nhwc_min_unsigned";
   return "";
 }
 
@@ -491,31 +327,12 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
     return "linalg.depthwise_conv_3d_ncdhw_cdhw";
   if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp))
     return "linalg.depthwise_conv_3d_ndhwc_dhwc";
-
-  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
-  if (indexingMaps.size() < 3) return "";
-  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
-  Block *body = genericOp.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  // pooling_ndhwc_max
-  // pooling_ndhwc_min
-  // pooling_ndhwc_sum
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
-  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
-  // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4)) {
-    if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
-      return "linalg.pooling_ndhwc_max";
-    if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
-      return "linalg.pooling_ndhwc_min";
-    if (bodyMatcherForSumPoolOps(yieldVal, body))
-      return "linalg.pooling_ndhwc_sum";
-  }
+  if (isaPoolingNdhwcMaxOp(genericOp))
+    return "linalg.pooling_ndhwc_max";
+  if (isaPoolingNdhwcMinOp(genericOp))
+    return "linalg.pooling_ndhwc_min";
+  if (isaPoolingNdhwcSumOp(genericOp))
+    return "linalg.pooling_ndhwc_sum";
   return "";
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 2d6d51d858853..127e61d7db050 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -853,6 +853,334 @@ bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op) {
       matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4));
 }
 
+bool isaPoolingNchwMaxOp(LinalgOp op) {
+  if (isa<linalg::PoolingNchwMaxOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
+  
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3) &&
+      bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+}
+
+bool isaPoolingNchwSumOp(LinalgOp op) {
+  if (isa<linalg::PoolingNchwSumOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
+  
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3) &&
+      bodyMatcherForSumPoolOps(yieldVal, body));
+}
+
+bool isaPoolingNhwcMaxOp(LinalgOp op) {
+  if (isa<linalg::PoolingNhwcMaxOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
+  
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+  // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+      bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+}
+
+bool isaPoolingNhwcMinOp(LinalgOp op) {
+  if (isa<linalg::PoolingNhwcMinOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
+  
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+  // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+      bodyMatcherForMinSignedPoolOps(yieldVal, body));
+}
+
+bool isaPoolingNhwcSumOp(LinalgOp op) {
+  if (isa<linalg::PoolingNhwcSumOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
+  
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+  // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+      bodyMatcherForSumPoolOps(yieldVal, body));
+}
+
+bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op) {
+  if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
+  
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+      bodyMatcherForMaxUnsignedPoolOps(yieldVal, body));
+}
+
+bool isaPoolingNhwcMinUnsignedOp(LinalgOp op) {
+  if (isa<linalg::PoolingNhwcMinUnsignedOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
+  
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+      bodyMatcherForMinUnsignedPoolOps(yieldVal, body));
+}
+
+bool isaPoolingNcwMaxOp(LinalgOp op) {
+  if (isa<linalg::PoolingNcwMaxOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
+  
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
+  // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
+  // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
+      bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+}
+
+bool isaPoolingNcwSumOp(LinalgOp op) {
+  if (isa<linalg::PoolingNcwSumOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
+  
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
+  // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
+  // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
+      bodyMatcherForSumPoolOps(yieldVal, body));
+}
+
+bool isaPoolingNwcMaxOp(LinalgOp op) {
+  if (isa<linalg::PoolingNwcMaxOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
+  
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+  // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
+  // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+      bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+}
+
+bool isaPoolingNwcMinOp(LinalgOp op) {
+  if (isa<linalg::PoolingNwcMinOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
+  
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+  // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
+  // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+      bodyMatcherForMinSignedPoolOps(yieldVal, body));
+}
+
+bool isaPoolingNwcSumOp(LinalgOp op) {
+  if (isa<linalg::PoolingNwcSumOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
+  
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+  // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
+  // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+      bodyMatcherForSumPoolOps(yieldVal, body));
+}
+
+bool isaPoolingNdhwcMaxOp(LinalgOp op) {
+  if (isa<linalg::PoolingNdhwcMaxOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5,3,5})) return false;
+  
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
+  // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+      bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+}
+
+bool isaPoolingNdhwcMinOp(LinalgOp op) {
+  if (isa<linalg::PoolingNdhwcMinOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5,3,5})) return false;
+  
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
+  // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+      bodyMatcherForMinSignedPoolOps(yieldVal, body));
+}
+
+bool isaPoolingNdhwcSumOp(LinalgOp op) {
+  if (isa<linalg::PoolingNdhwcSumOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5,3,5})) return false;
+  
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
+  // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+      bodyMatcherForSumPoolOps(yieldVal, body));
+}
+
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                             Value source, Value pad, bool nofold,
                             ValueRange typeDynDims) {

>From b06ba750c040a5b2506271ccf96d60fceb02cdef Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 9 Oct 2025 03:30:56 -0500
Subject: [PATCH 13/18] Renumber affine map labels in pooling matcher comments

---
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 66 ++++++++++++-------------
 1 file changed, 33 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 127e61d7db050..e847f2cf3aef2 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -909,9 +909,9 @@ bool isaPoolingNhwcMaxOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
-  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
-  // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
   return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
@@ -931,9 +931,9 @@ bool isaPoolingNhwcMinOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
-  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
-  // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
   return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
@@ -953,9 +953,9 @@ bool isaPoolingNhwcSumOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
-  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
-  // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
   return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
@@ -1019,9 +1019,9 @@ bool isaPoolingNcwMaxOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
-  // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
-  // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
   return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
@@ -1040,9 +1040,9 @@ bool isaPoolingNcwSumOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
-  // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
-  // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
   return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
@@ -1061,9 +1061,9 @@ bool isaPoolingNwcMaxOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
-  // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
-  // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+  // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
   return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
@@ -1082,9 +1082,9 @@ bool isaPoolingNwcMinOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
-  // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
-  // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+  // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
   return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
@@ -1103,9 +1103,9 @@ bool isaPoolingNwcSumOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
-  // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
-  // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+  // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
   return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
@@ -1124,9 +1124,9 @@ bool isaPoolingNdhwcMaxOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
-  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
-  // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
   return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
@@ -1147,9 +1147,9 @@ bool isaPoolingNdhwcMinOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
-  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
-  // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
   return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
@@ -1170,9 +1170,9 @@ bool isaPoolingNdhwcSumOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
-  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
-  // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
   return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&

>From 5aeb3716ff7ddb376bb4a9db962431dba1b55b56 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 9 Oct 2025 04:42:10 -0500
Subject: [PATCH 14/18] Add matchers for missing quantized convolution ops

---
 .../include/mlir/Dialect/Linalg/Utils/Utils.h |  4 +
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 16 ++++
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 83 +++++++++++++++++++
 3 files changed, 103 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 1a1b70d3eb979..2f7868bd55182 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -128,15 +128,19 @@ bool isaConv2DNhwcFhwcQOp(LinalgOp op);
 bool isaConv2DNchwFchwQOp(LinalgOp op);
 bool isaConv2DNgchwFgchwOp(LinalgOp op);
 bool isaConv2DNgchwGfchwOp(LinalgOp op);
+bool isaConv2DNhwcHwcfQOp(LinalgOp op);
+bool isaConv2DNhwgcGfhwcQOp(LinalgOp op);
 bool isaConv2DNgchwGfchwQOp(LinalgOp op);
 bool isaConv2DNhwgcGfhwcOp(LinalgOp op);
 bool isaDepthwiseConv2DNchwChwOp(LinalgOp op);
 bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op);
 bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op);
+bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op);
 bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op);
 bool isaConv3DOp(LinalgOp op);
 bool isaConv3DNcdhwFcdhwOp(LinalgOp op);
 bool isaConv3DNdhwcDhwcfOp(LinalgOp op);
+bool isaConv3DNdhwcDhwcfQOp(LinalgOp op);
 bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op);
 bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op);
 bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index aef3a1480d289..031cb3b919b96 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -277,6 +277,8 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
     return "linalg.depthwise_conv_2d_nchw_chw";
   if (isaDepthwiseConv2DNhwcHwcOp(genericOp))
     return "linalg.depthwise_conv_2d_nhwc_hwc";
+  if (isaDepthwiseConv2DNhwcHwcQOp(genericOp))
+    return "linalg.depthwise_conv_2d_nhwc_hwc_q";
   if (isaConv3DOp(genericOp))
     return "linalg.conv_3d";
   if (isaPoolingNchwMaxOp(genericOp))
@@ -307,6 +309,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
     return "linalg.conv_2d_nhwc_fhwc_q";
   if (isaConv2DNchwFchwQOp(genericOp))
     return "linalg.conv_2d_nchw_fchw_q";
+  if (isaConv2DNhwcHwcfQOp(genericOp))
+    return "linalg.conv_2d_nhwc_hwcf_q";
   if (isaDepthwiseConv2DNhwcHwcmOp(genericOp))
     return "linalg.depthwise_conv_2d_nhwc_hwcm";
   if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp))
@@ -323,6 +327,8 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
     return "linalg.conv_2d_ngchw_gfchw_q";
   if (isaConv2DNhwgcGfhwcOp(genericOp))
     return "linalg.conv_2d_nhwgc_gfhwc";
+  if (isaConv2DNhwgcGfhwcQOp(genericOp))
+    return "linalg.conv_2d_nhwgc_gfhwc_q";
   if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp))
     return "linalg.depthwise_conv_3d_ncdhw_cdhw";
   if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp))
@@ -341,6 +347,8 @@ static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) {
     return "linalg.conv_3d_ncdhw_fcdhw";
   if (isaConv3DNdhwcDhwcfOp(genericOp))
     return "linalg.conv_3d_ndhwc_dhwcf";
+  if (isaConv3DNdhwcDhwcfQOp(genericOp))
+    return "linalg.conv_3d_ndhwc_dhwcf_q";
   if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp))
     return "linalg.depthwise_conv_3d_ndhwc_dhwcm";
   return "";
@@ -412,6 +420,10 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
     namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNgchwGfchwQOp>(genericOp, resultTypes, inputs, outputs);
   } else if (convKind == "linalg.conv_2d_nhwgc_gfhwc") {
     namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwgcGfhwcOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.conv_2d_nhwc_hwcf_q") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcHwcfQOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.conv_2d_nhwgc_gfhwc_q") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwgcGfhwcQOp>(genericOp, resultTypes, inputs, outputs);
   } else if (convKind == "linalg.depthwise_conv_2d_nchw_chw") {
     namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNchwChwOp>(genericOp, resultTypes, inputs, outputs);
   } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwc") {
@@ -420,12 +432,16 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
     namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcmOp>(genericOp, resultTypes, inputs, outputs);
   } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwcm_q") {
     namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcmQOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwc_q") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcQOp>(genericOp, resultTypes, inputs, outputs);
   } else if (convKind == "linalg.conv_3d") {
     namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DOp>(genericOp, resultTypes, inputs, outputs);
   } else if (convKind == "linalg.conv_3d_ncdhw_fcdhw") {
     namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNcdhwFcdhwOp>(genericOp, resultTypes, inputs, outputs);
   } else if (convKind == "linalg.conv_3d_ndhwc_dhwcf") {
     namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNdhwcDhwcfOp>(genericOp, resultTypes, inputs, outputs);
+  } else if (convKind == "linalg.conv_3d_ndhwc_dhwcf_q") {
+    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNdhwcDhwcfQOp>(genericOp, resultTypes, inputs, outputs);
   } else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwcm") {
     namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(genericOp, resultTypes, inputs, outputs);
   } else if (convKind == "linalg.depthwise_conv_3d_ncdhw_cdhw") {
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index e847f2cf3aef2..b239f62a7049d 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -614,6 +614,48 @@ bool isaConv2DNgchwGfchwOp(LinalgOp op) {
       matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2));
 }
 
+/// Returns true if `op` is a linalg.conv_2d_nhwc_hwcf_q: either the named op
+/// itself, or an op implementing the convolution interface whose indexing
+/// maps have the conv_2d_nhwc_hwcf_q structure shown in the comment below.
+///
+/// Operand positions used: 0 = input, 1 = filter, 4 = output. The
+/// {4,4,0,0,4} passed to verifyConvIndexingMapSizes presumably gives the
+/// expected result count of each of the five maps. Operands 2 and 3 are
+/// rank-0 (presumably the input/filter zero points -- confirm).
+bool isaConv2DNhwcHwcfQOp(LinalgOp op) {
+  if (isa<linalg::Conv2DNhwcHwcfQOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false;
+  
+  unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+  // (#map2 presumably applies to both rank-0 operands.)
+  //
+  // Dimension correspondences checked below:
+  //   n:    input[0] == output[0]
+  //   h, w: input[1] / input[2] are the `d1 + d4` / `d2 + d5` sums of
+  //         output[1|2] and filter[0|1]
+  //   c:    input[3] == filter[2]
+  //   f:    filter[3] == output[3]
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3));
+}
+
+/// Returns true if `op` is a linalg.conv_2d_nhwgc_gfhwc_q: either the named op
+/// itself, or an op implementing the convolution interface whose indexing
+/// maps have the grouped conv_2d_nhwgc_gfhwc_q structure shown below.
+///
+/// Operand positions used: 0 = input, 1 = filter, 4 = output. The
+/// {5,5,0,0,5} passed to verifyConvIndexingMapSizes presumably gives the
+/// expected result count of each of the five maps. Operands 2 and 3 are
+/// rank-0 (presumably the input/filter zero points -- confirm).
+bool isaConv2DNhwgcGfhwcQOp(LinalgOp op) {
+  if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
+
+  unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+  // #map  = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+  //
+  // Dimension correspondences checked below:
+  //   n:    input[0] == output[0]
+  //   h, w: input[1] / input[2] are the `d1 + d5` / `d2 + d6` sums of
+  //         output[1|2] and filter[2|3]
+  //   g:    input[3] == filter[0] and input[3] == output[3]
+  //   c:    input[4] == filter[4]
+  //   f:    filter[1] == output[4]
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4));
+}
+
 bool isaConv2DNgchwGfchwQOp(LinalgOp op) {
   if (isa<linalg::Conv2DNgchwGfchwQOp>(op)) return true;
 
@@ -736,6 +778,26 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) {
       matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
 }
 
+/// Returns true if `op` is a linalg.depthwise_conv_2d_nhwc_hwc_q: either the
+/// named op itself, or an op implementing the convolution interface whose
+/// indexing maps have the structure shown in the comment below.
+///
+/// Operand positions used: 0 = input, 1 = filter, 4 = output. The
+/// {4,3,0,0,4} passed to verifyConvIndexingMapSizes presumably gives the
+/// expected result count of each of the five maps (the filter has no
+/// separate output-channel dim). Operands 2 and 3 are rank-0 (presumably
+/// the zero points -- confirm).
+bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,0,0,4})) return false;
+  
+  unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  //
+  // Dimension correspondences checked below:
+  //   n:    input[0] == output[0]
+  //   h, w: input[1] / input[2] are the `d1 + d4` / `d2 + d5` sums of
+  //         output[1|2] and filter[0|1]
+  //   c:    input[3] == filter[2] and input[3] == output[3] (depthwise: the
+  //         channel dim is shared by input, filter and output)
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3));
+}
+
 bool isaConv3DOp(LinalgOp op) {
   if (isa<linalg::Conv3DOp>(op)) return true;
 
@@ -792,6 +854,27 @@ bool isaConv3DNdhwcDhwcfOp(LinalgOp op) {
       matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
 }
 
+/// Returns true if `op` is a linalg.conv_3d_ndhwc_dhwcf_q: either the named op
+/// itself, or an op implementing the convolution interface whose indexing
+/// maps have the conv_3d_ndhwc_dhwcf_q structure shown below.
+///
+/// Operand positions used: 0 = input, 1 = filter, 4 = output. The
+/// {5,5,0,0,5} passed to verifyConvIndexingMapSizes presumably gives the
+/// expected result count of each of the five maps. Operands 2 and 3 are
+/// rank-0 (presumably the input/filter zero points -- confirm).
+bool isaConv3DNdhwcDhwcfQOp(LinalgOp op) {
+  if (isa<linalg::Conv3DNdhwcDhwcfQOp>(op)) return true;
+
+  if (!isaConvolutionOpInterface(op)) return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
+  
+  unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> ()>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
+  //
+  // Dimension correspondences checked below:
+  //   n:       input[0] == output[0]
+  //   d, h, w: input[1..3] are the sums of output[1..3] and filter[0..2]
+  //   c:       input[4] == filter[3]
+  //   f:       filter[4] == output[4]
+  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+      matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
+}
+
 bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) {
   if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op)) return true;
 

>From f1b8e80ff65c4db082b634eadc3976c35dc4ccac Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 14 Oct 2025 04:37:28 -0500
Subject: [PATCH 15/18] Make use of dilations/strides info

---
 .../include/mlir/Dialect/Linalg/Utils/Utils.h |  84 +--
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 355 +++++-------
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 546 ++++++++++++------
 3 files changed, 548 insertions(+), 437 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 2f7868bd55182..44ebc101d7c37 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -115,50 +115,50 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
 //===----------------------------------------------------------------------===//
 
+/// Convolution / pooling matchers. Each `isa*Op(op, ...)` returns true if `op`
+/// is the corresponding named Linalg op, or an op whose structure matches it.
+/// Matchers for ops that carry strides/dilations attributes also take two
+/// optional out-params. When non-null, these are presumably filled with the
+/// matched op's dilations and strides (confirm in Utils.cpp). conv_1d/2d/3d
+/// have no such attributes and take no out-params.
 bool isaConv1DOp(LinalgOp op);
-bool isaConv1DNwcWcfOp(LinalgOp op);
-bool isaConv1DNcwFcwOp(LinalgOp op);
-bool isaDepthwiseConv1DNcwCwOp(LinalgOp op);
-bool isaDepthwiseConv1DNwcWcOp(LinalgOp op);
-bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op);
+bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
 bool isaConv2DOp(LinalgOp op);
-bool isaConv2DNhwcFhwcOp(LinalgOp op);
-bool isaConv2DNhwcHwcfOp(LinalgOp op);
-bool isaConv2DNchwFchwOp(LinalgOp op);
-bool isaConv2DNhwcFhwcQOp(LinalgOp op);
-bool isaConv2DNchwFchwQOp(LinalgOp op);
-bool isaConv2DNgchwFgchwOp(LinalgOp op);
-bool isaConv2DNgchwGfchwOp(LinalgOp op);
-bool isaConv2DNhwcHwcfQOp(LinalgOp op);
-bool isaConv2DNhwgcGfhwcQOp(LinalgOp op);
-bool isaConv2DNgchwGfchwQOp(LinalgOp op);
-bool isaConv2DNhwgcGfhwcOp(LinalgOp op);
-bool isaDepthwiseConv2DNchwChwOp(LinalgOp op);
-bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op);
-bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op);
-bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op);
-bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op);
+bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
 bool isaConv3DOp(LinalgOp op);
-bool isaConv3DNcdhwFcdhwOp(LinalgOp op);
-bool isaConv3DNdhwcDhwcfOp(LinalgOp op);
-bool isaConv3DNdhwcDhwcfQOp(LinalgOp op);
-bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op);
-bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op);
-bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op);
-bool isaPoolingNchwMaxOp(LinalgOp op);
-bool isaPoolingNchwSumOp(LinalgOp op);
-bool isaPoolingNhwcMaxOp(LinalgOp op);
-bool isaPoolingNhwcMinOp(LinalgOp op);
-bool isaPoolingNhwcSumOp(LinalgOp op);
-bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op);
-bool isaPoolingNhwcMinUnsignedOp(LinalgOp op);
-bool isaPoolingNcwMaxOp(LinalgOp op);
-bool isaPoolingNcwSumOp(LinalgOp op);
-bool isaPoolingNwcMaxOp(LinalgOp op);
-bool isaPoolingNwcMinOp(LinalgOp op);
-bool isaPoolingNwcSumOp(LinalgOp op);
-bool isaPoolingNdhwcMaxOp(LinalgOp op);
-bool isaPoolingNdhwcMinOp(LinalgOp op);
-bool isaPoolingNdhwcSumOp(LinalgOp op);
+bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
 
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 031cb3b919b96..94dfbcc15d055 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,250 +237,169 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
 }
 
-static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
-  if (isaConv1DOp(genericOp)) return "linalg.conv_1d";
-  return "";
+/// Replaces `genericOp` with a newly created named op of type `ConvOpTy`,
+/// reusing the generic's DPS inputs and inits, and returns the new op (this
+/// helper has no failure path).
+///
+/// linalg.conv_1d/2d/3d have no strides/dilations attributes, so for those
+/// `dilations` and `strides` are ignored. For every other op they are
+/// attached as i64 tensor attributes.
+///
+/// Result types are the init types when `genericOp` has pure tensor
+/// semantics, and empty otherwise.
+template <typename ConvOpTy>
+static FailureOr<LinalgOp> specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
+  SmallVector<Value> inputs = genericOp.getDpsInputs();
+  ValueRange outputs = genericOp.getDpsInits();
+  SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
+                                      ? TypeRange(ValueRange(outputs))
+                                      : TypeRange{};
+  LinalgOp namedOp;
+  if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes, inputs, outputs);
+  } else {
+    Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+    Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+    // Note the builder argument order: strides first, then dilations.
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+  }
+  return namedOp;
 }
 
-static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
-  if (isaDepthwiseConv1DNcwCwOp(genericOp))
-    return "linalg.depthwise_conv_1d_ncw_cw";
-  if (isaDepthwiseConv1DNwcWcOp(genericOp))
-    return "linalg.depthwise_conv_1d_nwc_wc";
+// Tries the matchers for linalg.generic ops with 2 loops (currently only
+// linalg.conv_1d). On a match, `genericOp` is replaced via specializeToConvOp.
+// Otherwise failure() is returned without rewriting anything.
+static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+  // conv_1d has no strides/dilations attributes, so these stay empty.
+  SmallVector<int64_t> dilations, strides;
+  if (isaConv1DOp(genericOp)) return specializeToConvOp<linalg::Conv1DOp>(rewriter, genericOp, dilations, strides);
+  return failure();
+}
+
+// Tries, in order, the matchers for linalg.generic ops with 4 loops:
+// - depthwise 1-D convolutions (ncw_cw, nwc_wc),
+// - linalg.conv_2d,
+// - the 1-D pooling ops (ncw max/sum, nwc max/min/sum).
+// The first matcher that succeeds determines the named op. The dilations and
+// strides it reports through its out-params are forwarded to
+// specializeToConvOp. Returns failure() if no matcher applies.
+// NOTE(review): `dilations`/`strides` are reused across all matcher calls;
+// confirm a failing matcher never leaves partial values in them.
+static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaDepthwiseConv1DNcwCwOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv1DNcwCwOp>(rewriter, genericOp, dilations, strides);
+  if (isaDepthwiseConv1DNwcWcOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(rewriter, genericOp, dilations, strides);
   if (isaConv2DOp(genericOp))
-    return "linalg.conv_2d";
-  if (isaPoolingNcwMaxOp(genericOp))
-    return "linalg.pooling_ncw_max";
-  if (isaPoolingNcwSumOp(genericOp))
-    return "linalg.pooling_ncw_sum";
-  if (isaPoolingNwcMaxOp(genericOp))
-    return "linalg.pooling_nwc_max";
-  if (isaPoolingNwcMinOp(genericOp))
-    return "linalg.pooling_nwc_min";
-  if (isaPoolingNwcSumOp(genericOp))
-    return "linalg.pooling_nwc_sum";
-  return "";
+    return specializeToConvOp<linalg::Conv2DOp>(rewriter, genericOp, dilations, strides);
+  if (isaPoolingNcwMaxOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNcwMaxOp>(rewriter, genericOp, dilations, strides);
+  if (isaPoolingNcwSumOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNcwSumOp>(rewriter, genericOp, dilations, strides);
+  if (isaPoolingNwcMaxOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNwcMaxOp>(rewriter, genericOp, dilations, strides);
+  if (isaPoolingNwcMinOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNwcMinOp>(rewriter, genericOp, dilations, strides);
+  if (isaPoolingNwcSumOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNwcSumOp>(rewriter, genericOp, dilations, strides);
+  return failure();
 }
 
-static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
-  if (isaDepthwiseConv1DNwcWcmOp(genericOp))
-    return "linalg.depthwise_conv_1d_nwc_wcm";
-  if (isaConv1DNwcWcfOp(genericOp))
-    return "linalg.conv_1d_nwc_wcf";
-  if (isaConv1DNcwFcwOp(genericOp))
-    return "linalg.conv_1d_ncw_fcw";
-  return "";
+// Tries, in order, the matchers for linalg.generic ops with 5 loops:
+// depthwise_conv_1d_nwc_wcm, conv_1d_nwc_wcf and conv_1d_ncw_fcw.
+// The first matcher that succeeds determines the named op. The dilations and
+// strides it reports through its out-params are forwarded to
+// specializeToConvOp. Returns failure() if no matcher applies.
+// NOTE(review): `dilations`/`strides` are reused across the matcher calls;
+// confirm a failing matcher never leaves partial values in them.
+static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaDepthwiseConv1DNwcWcmOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv1DNwcWcmOp>(rewriter, genericOp, dilations, strides);
+  if (isaConv1DNwcWcfOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::Conv1DNwcWcfOp>(rewriter, genericOp, dilations, strides);
+  if (isaConv1DNcwFcwOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::Conv1DNcwFcwOp>(rewriter, genericOp, dilations, strides);
+  return failure();
 }
 
-static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
-  if (isaDepthwiseConv2DNchwChwOp(genericOp))
-    return "linalg.depthwise_conv_2d_nchw_chw";
-  if (isaDepthwiseConv2DNhwcHwcOp(genericOp))
-    return "linalg.depthwise_conv_2d_nhwc_hwc";
-  if (isaDepthwiseConv2DNhwcHwcQOp(genericOp))
-    return "linalg.depthwise_conv_2d_nhwc_hwc_q";
+// Tries, in order, the matchers for linalg.generic ops with 6 loops:
+// - depthwise 2-D convolutions (nchw_chw, nhwc_hwc, nhwc_hwc_q),
+// - linalg.conv_3d,
+// - the 2-D pooling ops (nchw max/sum; nhwc max/min/sum and the unsigned
+//   max/min variants).
+// The first matcher that succeeds determines the named op. The dilations and
+// strides it reports through its out-params are forwarded to
+// specializeToConvOp. Returns failure() if no matcher applies.
+// NOTE(review): `dilations`/`strides` are reused across the matcher calls;
+// confirm a failing matcher never leaves partial values in them.
+static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaDepthwiseConv2DNchwChwOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(rewriter, genericOp, dilations, strides);
+  if (isaDepthwiseConv2DNhwcHwcOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcOp>(rewriter, genericOp, dilations, strides);
+  if (isaDepthwiseConv2DNhwcHwcQOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcQOp>(rewriter, genericOp, dilations, strides);
   if (isaConv3DOp(genericOp))
-    return "linalg.conv_3d";
-  if (isaPoolingNchwMaxOp(genericOp))
-    return "linalg.pooling_nchw_max";
-  if (isaPoolingNchwSumOp(genericOp))
-    return "linalg.pooling_nchw_sum";
-  if (isaPoolingNhwcMaxOp(genericOp))
-    return "linalg.pooling_nhwc_max";
-  if (isaPoolingNhwcMinOp(genericOp))
-    return "linalg.pooling_nhwc_min";
-  if (isaPoolingNhwcSumOp(genericOp))
-    return "linalg.pooling_nhwc_sum";
-  if (isaPoolingNhwcMaxUnsignedOp(genericOp))
-    return "linalg.pooling_nhwc_max_unsigned";
-  if (isaPoolingNhwcMinUnsignedOp(genericOp))
-    return "linalg.pooling_nhwc_min_unsigned";
-  return "";
+    return specializeToConvOp<linalg::Conv3DOp>(rewriter, genericOp, dilations, strides);
+  if (isaPoolingNchwMaxOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNchwMaxOp>(rewriter, genericOp, dilations, strides);
+  if (isaPoolingNchwSumOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNchwSumOp>(rewriter, genericOp, dilations, strides);
+  if (isaPoolingNhwcMaxOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp, dilations, strides);
+  if (isaPoolingNhwcMinOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp, dilations, strides);
+  if (isaPoolingNhwcSumOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp, dilations, strides);
+  if (isaPoolingNhwcMaxUnsignedOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(rewriter, genericOp, dilations, strides);
+  if (isaPoolingNhwcMinUnsignedOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(rewriter, genericOp, dilations, strides);
+  return failure();
 }
 
-static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
-  if (isaConv2DNhwcFhwcOp(genericOp))
-    return "linalg.conv_2d_nhwc_fhwc";
-  if (isaConv2DNhwcHwcfOp(genericOp))
-    return "linalg.conv_2d_nhwc_hwcf";
-  if (isaConv2DNchwFchwOp(genericOp))
-    return "linalg.conv_2d_nchw_fchw";
-  if (isaConv2DNhwcFhwcQOp(genericOp))
-    return "linalg.conv_2d_nhwc_fhwc_q";
-  if (isaConv2DNchwFchwQOp(genericOp))
-    return "linalg.conv_2d_nchw_fchw_q";
-  if (isaConv2DNhwcHwcfQOp(genericOp))
-    return "linalg.conv_2d_nhwc_hwcf_q";
-  if (isaDepthwiseConv2DNhwcHwcmOp(genericOp))
-    return "linalg.depthwise_conv_2d_nhwc_hwcm";
-  if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp))
-    return "linalg.depthwise_conv_2d_nhwc_hwcm_q";
-  return "";
+// Tries, in order, the matchers for linalg.generic ops with 7 loops:
+// - the non-grouped 2-D convolutions (nhwc_fhwc, nhwc_hwcf, nchw_fchw and
+//   their _q variants),
+// - depthwise_conv_2d_nhwc_hwcm(_q).
+// The first matcher that succeeds determines the named op. The dilations and
+// strides it reports through its out-params are forwarded to
+// specializeToConvOp. Returns failure() if no matcher applies.
+// NOTE(review): `dilations`/`strides` are reused across the matcher calls;
+// confirm a failing matcher never leaves partial values in them.
+static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaConv2DNhwcFhwcOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::Conv2DNhwcFhwcOp>(rewriter, genericOp, dilations, strides);
+  if (isaConv2DNhwcHwcfOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::Conv2DNhwcHwcfOp>(rewriter, genericOp, dilations, strides);
+  if (isaConv2DNchwFchwOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::Conv2DNchwFchwOp>(rewriter, genericOp, dilations, strides);
+  if (isaConv2DNhwcFhwcQOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::Conv2DNhwcFhwcQOp>(rewriter, genericOp, dilations, strides);
+  if (isaConv2DNchwFchwQOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::Conv2DNchwFchwQOp>(rewriter, genericOp, dilations, strides);
+  if (isaConv2DNhwcHwcfQOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::Conv2DNhwcHwcfQOp>(rewriter, genericOp, dilations, strides);
+  if (isaDepthwiseConv2DNhwcHwcmOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcmOp>(rewriter, genericOp, dilations, strides);
+  if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcmQOp>(rewriter, genericOp, dilations, strides);
+  return failure();
 }
 
-static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
-  if (isaConv2DNgchwFgchwOp(genericOp))
-    return "linalg.conv_2d_ngchw_fgchw";
-  if (isaConv2DNgchwGfchwOp(genericOp))
-    return "linalg.conv_2d_ngchw_gfchw";
-  if (isaConv2DNgchwGfchwQOp(genericOp))
-    return "linalg.conv_2d_ngchw_gfchw_q";
-  if (isaConv2DNhwgcGfhwcOp(genericOp))
-    return "linalg.conv_2d_nhwgc_gfhwc";
-  if (isaConv2DNhwgcGfhwcQOp(genericOp))
-    return "linalg.conv_2d_nhwgc_gfhwc_q";
-  if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp))
-    return "linalg.depthwise_conv_3d_ncdhw_cdhw";
-  if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp))
-    return "linalg.depthwise_conv_3d_ndhwc_dhwc";
-  if (isaPoolingNdhwcMaxOp(genericOp))
-    return "linalg.pooling_ndhwc_max";
-  if (isaPoolingNdhwcMinOp(genericOp))
-    return "linalg.pooling_ndhwc_min";
-  if (isaPoolingNdhwcSumOp(genericOp))
-    return "linalg.pooling_ndhwc_sum";
-  return "";
+// Tries, in order, the matchers for linalg.generic ops with 8 loops:
+// - the grouped 2-D convolutions (ngchw_fgchw, ngchw_gfchw(_q),
+//   nhwgc_gfhwc(_q)),
+// - depthwise 3-D convolutions (ncdhw_cdhw, ndhwc_dhwc),
+// - the ndhwc max/min/sum pooling ops.
+// The first matcher that succeeds determines the named op. The dilations and
+// strides it reports through its out-params are forwarded to
+// specializeToConvOp. Returns failure() if no matcher applies.
+// NOTE(review): `dilations`/`strides` are reused across the matcher calls;
+// confirm a failing matcher never leaves partial values in them.
+static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaConv2DNgchwFgchwOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::Conv2DNgchwFgchwOp>(rewriter, genericOp, dilations, strides);
+  if (isaConv2DNgchwGfchwOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::Conv2DNgchwGfchwOp>(rewriter, genericOp, dilations, strides);
+  if (isaConv2DNgchwGfchwQOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::Conv2DNgchwGfchwQOp>(rewriter, genericOp, dilations, strides);
+  if (isaConv2DNhwgcGfhwcOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::Conv2DNhwgcGfhwcOp>(rewriter, genericOp, dilations, strides);
+  if (isaConv2DNhwgcGfhwcQOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::Conv2DNhwgcGfhwcQOp>(rewriter, genericOp, dilations, strides);
+  if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv3DNcdhwCdhwOp>(rewriter, genericOp, dilations, strides);
+  if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(rewriter, genericOp, dilations, strides);
+  if (isaPoolingNdhwcMaxOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNdhwcMaxOp>(rewriter, genericOp, dilations, strides);
+  if (isaPoolingNdhwcMinOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNdhwcMinOp>(rewriter, genericOp, dilations, strides);
+  if (isaPoolingNdhwcSumOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNdhwcSumOp>(rewriter, genericOp, dilations, strides);
+  return failure();
 }
 
-static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) {
-  if (isaConv3DNcdhwFcdhwOp(genericOp))
-    return "linalg.conv_3d_ncdhw_fcdhw";
-  if (isaConv3DNdhwcDhwcfOp(genericOp))
-    return "linalg.conv_3d_ndhwc_dhwcf";
-  if (isaConv3DNdhwcDhwcfQOp(genericOp))
-    return "linalg.conv_3d_ndhwc_dhwcf_q";
-  if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp))
-    return "linalg.depthwise_conv_3d_ndhwc_dhwcm";
-  return "";
+// Tries, in order, the matchers for linalg.generic ops with 9 loops:
+// conv_3d_ncdhw_fcdhw, conv_3d_ndhwc_dhwcf(_q) and
+// depthwise_conv_3d_ndhwc_dhwcm.
+// The first matcher that succeeds determines the named op. The dilations and
+// strides it reports through its out-params are forwarded to
+// specializeToConvOp. Returns failure() if no matcher applies.
+// NOTE(review): `dilations`/`strides` are reused across the matcher calls;
+// confirm a failing matcher never leaves partial values in them.
+static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaConv3DNcdhwFcdhwOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::Conv3DNcdhwFcdhwOp>(rewriter, genericOp, dilations, strides);
+  if (isaConv3DNdhwcDhwcfOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::Conv3DNdhwcDhwcfOp>(rewriter, genericOp, dilations, strides);
+  if (isaConv3DNdhwcDhwcfQOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::Conv3DNdhwcDhwcfQOp>(rewriter, genericOp, dilations, strides);
+  if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(rewriter, genericOp, dilations, strides);
+  return failure();
 }
 
-static std::string inferConvolutionKind(GenericOp genericOp) {
+// Converts linalg.generic to named linalg.*conv* where possible.
+static FailureOr<LinalgOp> inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
   SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
   unsigned totalIterators = iteratorTypes.size();
   switch(totalIterators) {
     case 2:
-      return inferBasedOnRank2ConvIteratorTypes(genericOp);
+      return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
     case 4:
-      return inferBasedOnRank4ConvIteratorTypes(genericOp);
+      return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
     case 5:
-      return inferBasedOnRank5ConvIteratorTypes(genericOp);
+      return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
     case 6:
-      return inferBasedOnRank6ConvIteratorTypes(genericOp);
+      return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
     case 7:
-      return inferBasedOnRank7ConvIteratorTypes(genericOp);
+      return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
     case 8:
-      return inferBasedOnRank8ConvIteratorTypes(genericOp);
+      return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
     case 9:
-      return inferBasedOnRank9ConvIteratorTypes(genericOp);
-  }
-  return "";
-}
-
-// Converts linalg.generic to named linalg.*conv* where possible.
-static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
-                                                        GenericOp genericOp) {
-  std::string convKind = inferConvolutionKind(genericOp);
-  if (convKind == "") return failure();
-  SmallVector<Value> inputs = genericOp.getDpsInputs();
-  ValueRange outputs = genericOp.getDpsInits();
-  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
-  SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
-                                      ? TypeRange(ValueRange(outputs))
-                                      : TypeRange{};
-  LinalgOp namedOp;
-  if (convKind == "linalg.conv_1d") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv1DOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.conv_1d_nwc_wcf") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv1DNwcWcfOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.conv_1d_ncw_fcw") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv1DNcwFcwOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.depthwise_conv_1d_ncw_cw") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNcwCwOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.depthwise_conv_1d_nwc_wc") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNwcWcOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.depthwise_conv_1d_nwc_wcm") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNwcWcmOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.conv_2d") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.conv_2d_nhwc_fhwc") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcFhwcOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.conv_2d_nhwc_hwcf") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcHwcfOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.conv_2d_nchw_fchw") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNchwFchwOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.conv_2d_nhwc_fhwc_q") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcFhwcQOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.conv_2d_nchw_fchw_q") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNchwFchwQOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.conv_2d_ngchw_fgchw") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNgchwFgchwOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.conv_2d_ngchw_gfchw") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNgchwGfchwOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.conv_2d_ngchw_gfchw_q") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNgchwGfchwQOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.conv_2d_nhwgc_gfhwc") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwgcGfhwcOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.conv_2d_nhwc_hwcf_q") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcHwcfQOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.conv_2d_nhwgc_gfhwc_q") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwgcGfhwcQOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.depthwise_conv_2d_nchw_chw") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNchwChwOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwc") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwcm") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcmOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwcm_q") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcmQOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwc_q") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcQOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.conv_3d") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.conv_3d_ncdhw_fcdhw") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNcdhwFcdhwOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.conv_3d_ndhwc_dhwcf") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNdhwcDhwcfOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.conv_3d_ndhwc_dhwcf_q") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNdhwcDhwcfQOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwcm") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.depthwise_conv_3d_ncdhw_cdhw") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNcdhwCdhwOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwc") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.pooling_nchw_max") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNchwMaxOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.pooling_nchw_sum") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNchwSumOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.pooling_nhwc_max") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.pooling_nhwc_min") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMinOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.pooling_nhwc_sum") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcSumOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.pooling_nhwc_max_unsigned") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.pooling_nhwc_min_unsigned") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMinUnsignedOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.pooling_ncw_max") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNcwMaxOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.pooling_ncw_sum") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNcwSumOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.pooling_nwc_max") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNwcMaxOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.pooling_nwc_min") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNwcMinOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.pooling_nwc_sum") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNwcSumOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.pooling_ndhwc_max") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNdhwcMaxOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.pooling_ndhwc_min") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNdhwcMinOp>(genericOp, resultTypes, inputs, outputs);
-  } else if (convKind == "linalg.pooling_ndhwc_sum") {
-    namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNdhwcSumOp>(genericOp, resultTypes, inputs, outputs);
+      return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
   }
-  return namedOp;
-
   return failure();
 }
 
@@ -566,7 +485,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
 
   // Convolution - e.g. *conv*
   if (isaConvolutionOpInterface(genericOp)) {
-    return specializeLinalgConvolutions(rewriter, genericOp);
+    return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
   }
   return failure();
 }
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index b239f62a7049d..548f43f83b0ed 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -319,7 +319,14 @@ static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_
   return false;
 }
 
-static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim) {
+/// Matches result `iDim` of the input indexing map against an add of two
+/// (optionally constant-scaled) dims, one of which is result `fDim` of the
+/// filter map and the other result `oDim` of the output map. On a match the
+/// filter dim's coefficient is stored in `dilation` and the output dim's in
+/// `stride`; the out-params are only written when the pattern matches.
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
+                                       unsigned fDim, unsigned oDim,
+                                       int64_t &dilation, int64_t &stride) {
   unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
   AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
   auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
@@ -327,7 +328,6 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, un
     return false;
 
   AffineExpr dim0, dim1;
-  // TODO(Abhishek-Varma): Use this information in specialize.cpp.
   int64_t c0, c1;
 
   if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
@@ -335,7 +335,18 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, un
     // Pattern matched with dims and constants extracted.
     AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
     AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
-    return ((dim0 == fExpr && dim1 == oExpr) || (dim1 == fExpr && dim0 == oExpr));
+    // The input index is `dilation * fDim + stride * oDim` with the operands
+    // in either order, so attribute each coefficient by the dim it scales.
+    if (dim0 == fExpr && dim1 == oExpr) {
+      dilation = c0;
+      stride = c1;
+      return true;
+    }
+    if (dim1 == fExpr && dim0 == oExpr) {
+      dilation = c1;
+      stride = c0;
+      return true;
+    }
   }
   return false;
 }
@@ -354,6 +362,27 @@ static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, ArrayRef<int64_t>
   return true;
 }
 
+/// Appends the per-spatial-dimension dilations and strides collected by a
+/// convolution matcher to `dilations` / `strides`. Nothing is recorded when
+/// either output pointer is null. Always returns true so the call can be
+/// chained with `&&` after the structural match succeeds.
+///
+/// NOTE(review): the `isa<NamedConvOp>(op)` fast paths at the top of the
+/// matchers return before this is reached, so the output vectors are left
+/// untouched for ops that already are the named convolution.
+static bool updateConvDilationsAndStrides(SmallVectorImpl<int64_t> *dilations,
+                                          SmallVectorImpl<int64_t> *strides,
+                                          ArrayRef<int64_t> tempDilations,
+                                          ArrayRef<int64_t> tempStrides) {
+  assert(tempDilations.size() == tempStrides.size() &&
+         "expected one dilation and one stride per spatial dimension");
+  if (!(dilations && strides))
+    return true;
+  llvm::append_range(*dilations, tempDilations);
+  llvm::append_range(*strides, tempStrides);
+  return true;
+}
+
 bool isaConv1DOp(LinalgOp op) {
   if (isa<linalg::Conv1DOp>(op)) return true;
 
@@ -365,10 +383,15 @@ bool isaConv1DOp(LinalgOp op) {
   // #map = affine_map<(d0, d1) -> (d0 + d1)>
   // #map1 = affine_map<(d0, d1) -> (d1)>
   // #map2 = affine_map<(d0, d1) -> (d0)>
-  return matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0);
+  // linalg.conv_1d carries no dilation/stride attributes (both are implicitly
+  // 1), so only accept unit coefficients, i.e. exactly `d0 + d1`.
+  int64_t dilation = 1, stride = 1;
+  return matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
+                                    /*oDim=*/0, dilation, stride) &&
+         dilation == 1 && stride == 1;
 }
 
-bool isaConv1DNwcWcfOp(LinalgOp op) {
+bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::Conv1DNwcWcfOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -377,16 +397,20 @@ bool isaConv1DNwcWcfOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(1,1);
+  SmallVector<int64_t> tempStrides(1,1);
   // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaConv1DNcwFcwOp(LinalgOp op) {
+bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::Conv1DNcwFcwOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -395,16 +419,20 @@ bool isaConv1DNcwFcwOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(1,1);
+  SmallVector<int64_t> tempStrides(1,1);
   // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv1DNcwCwOp(LinalgOp op) {
+bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::DepthwiseConv1DNcwCwOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -413,16 +441,20 @@ bool isaDepthwiseConv1DNcwCwOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {3,2,3})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(1,1);
+  SmallVector<int64_t> tempStrides(1,1);
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2));
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[0], tempStrides[0]));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv1DNwcWcOp(LinalgOp op) {
+bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::DepthwiseConv1DNwcWcOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -431,16 +464,20 @@ bool isaDepthwiseConv1DNwcWcOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {3,2,3})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(1,1);
+  SmallVector<int64_t> tempStrides(1,1);
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1));
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op) {
+bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -449,14 +486,18 @@ bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,4})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(1,1);
+  SmallVector<int64_t> tempStrides(1,1);
   // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
 bool isaConv2DOp(LinalgOp op) {
@@ -467,14 +508,20 @@ bool isaConv2DOp(LinalgOp op) {
   ArrayAttr indexingMaps = op.getIndexingMaps();
   if (!verifyConvIndexingMapSizes(indexingMaps, {2,2,2})) return false;
   
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
-  return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1));
+  // linalg.conv_2d carries no dilation/stride attributes (both are implicitly
+  // 1), so only accept unit coefficients, i.e. exactly `(d0 + d2, d1 + d3)`.
+  auto isOne = [](int64_t v) { return v == 1; };
+  return matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[1], tempStrides[1]) &&
+      llvm::all_of(tempDilations, isOne) && llvm::all_of(tempStrides, isOne);
 }
 
-bool isaConv2DNhwcFhwcOp(LinalgOp op) {
+bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::Conv2DNhwcFhwcOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -483,17 +526,21 @@ bool isaConv2DNhwcFhwcOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaConv2DNhwcHwcfOp(LinalgOp op) {
+bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::Conv2DNhwcHwcfOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -502,17 +549,21 @@ bool isaConv2DNhwcHwcfOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaConv2DNchwFchwOp(LinalgOp op) {
+bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::Conv2DNchwFchwOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -521,17 +572,21 @@ bool isaConv2DNchwFchwOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaConv2DNhwcFhwcQOp(LinalgOp op) {
+bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::Conv2DNhwcFhwcQOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -540,18 +595,22 @@ bool isaConv2DNhwcFhwcQOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaConv2DNchwFchwQOp(LinalgOp op) {
+bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::Conv2DNchwFchwQOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -560,18 +619,22 @@ bool isaConv2DNchwFchwQOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaConv2DNgchwFgchwOp(LinalgOp op) {
+bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::Conv2DNgchwFgchwOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -580,19 +643,23 @@ bool isaConv2DNgchwFgchwOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaConv2DNgchwGfchwOp(LinalgOp op) {
+bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::Conv2DNgchwGfchwOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -601,20 +668,23 @@ bool isaConv2DNgchwGfchwOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  return (indexingMaps.size() == 3 &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaConv2DNhwcHwcfQOp(LinalgOp op) {
+bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::Conv2DNhwcHwcfQOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -623,18 +693,22 @@ bool isaConv2DNhwcHwcfQOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+  
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaConv2DNhwgcGfhwcQOp(LinalgOp op) {
+bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -643,20 +717,24 @@ bool isaConv2DNhwgcGfhwcQOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+  
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaConv2DNgchwGfchwQOp(LinalgOp op) {
+bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::Conv2DNgchwGfchwQOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -665,20 +743,24 @@ bool isaConv2DNgchwGfchwQOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+  // Per-dim dilations/strides (default 1) for the add-expr matchers below.
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaConv2DNhwgcGfhwcOp(LinalgOp op) {
+bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::Conv2DNhwgcGfhwcOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -687,19 +769,23 @@ bool isaConv2DNhwgcGfhwcOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // Per-dim dilations/strides (default 1) for the add-expr matchers below.
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv2DNchwChwOp(LinalgOp op) {
+bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::DepthwiseConv2DNchwChwOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -708,17 +794,21 @@ bool isaDepthwiseConv2DNchwChwOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,4})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // Per-dim dilations/strides (default 1) for the add-expr matchers below.
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3));
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[1], tempStrides[1]));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op) {
+bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -727,17 +817,21 @@ bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,4})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // Per-dim dilations/strides (default 1) for the add-expr matchers below.
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op) {
+bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -746,18 +840,22 @@ bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,5})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // Per-dim dilations/strides (default 1) for the add-expr matchers below.
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) {
+bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::DepthwiseConv2DNhwcHwcmQOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -766,19 +864,23 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,5})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+  // Per-dim dilations/strides (default 1) for the add-expr matchers below.
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op) {
+bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -787,15 +889,19 @@ bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,0,0,4})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+  // Per-dim dilations/strides (default 1) for the add-expr matchers below.
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
 bool isaConv3DOp(LinalgOp op) {
@@ -806,15 +912,19 @@ bool isaConv3DOp(LinalgOp op) {
   ArrayAttr indexingMaps = op.getIndexingMaps();
   if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false;
   
+  // NOTE(review): unlike its siblings, isaConv3DOp has no dilations/strides
+  // out-params, so the values matched into these temporaries are discarded.
+  SmallVector<int64_t> tempDilations(3,1);
+  SmallVector<int64_t> tempStrides(3,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
-  return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2));
+  return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[1], tempStrides[1]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[2], tempStrides[2]));
 }
 
-bool isaConv3DNcdhwFcdhwOp(LinalgOp op) {
+bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::Conv3DNcdhwFcdhwOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -823,18 +931,22 @@ bool isaConv3DNcdhwFcdhwOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // Per-dim dilations/strides (default 1) for the add-expr matchers below.
+  SmallVector<int64_t> tempDilations(3,1);
+  SmallVector<int64_t> tempStrides(3,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[2], tempStrides[2]) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaConv3DNdhwcDhwcfOp(LinalgOp op) {
+bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::Conv3DNdhwcDhwcfOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -843,18 +955,22 @@ bool isaConv3DNdhwcDhwcfOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // Per-dim dilations/strides (default 1) for the add-expr matchers below.
+  SmallVector<int64_t> tempDilations(3,1);
+  SmallVector<int64_t> tempStrides(3,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaConv3DNdhwcDhwcfQOp(LinalgOp op) {
+bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::Conv3DNdhwcDhwcfQOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -863,19 +979,23 @@ bool isaConv3DNdhwcDhwcfQOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+  // Per-dim dilations/strides (default 1) for the add-expr matchers below.
+  SmallVector<int64_t> tempDilations(3,1);
+  SmallVector<int64_t> tempStrides(3,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> ()>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) {
+bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -884,19 +1004,23 @@ bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,6})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // Per-dim dilations/strides (default 1) for the add-expr matchers below.
+  SmallVector<int64_t> tempDilations(3,1);
+  SmallVector<int64_t> tempStrides(3,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
       matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op) {
+bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::DepthwiseConv3DNcdhwCdhwOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -905,18 +1029,22 @@ bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // Per-dim dilations/strides (default 1) for the add-expr matchers below.
+  SmallVector<int64_t> tempDilations(3,1);
+  SmallVector<int64_t> tempStrides(3,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4));
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4, tempDilations[2], tempStrides[2]));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op) {
+bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::DepthwiseConv3DNdhwcDhwcOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -925,18 +1053,22 @@ bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op) {
   if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false;
   
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  // Per-dim dilations/strides (default 1) for the add-expr matchers below.
+  SmallVector<int64_t> tempDilations(3,1);
+  SmallVector<int64_t> tempStrides(3,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaPoolingNchwMaxOp(LinalgOp op) {
+bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::PoolingNchwMaxOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -948,17 +1080,21 @@ bool isaPoolingNchwMaxOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
+  // Per-dim dilations/strides (default 1) for the add-expr matchers below.
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
       bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaPoolingNchwSumOp(LinalgOp op) {
+bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::PoolingNchwSumOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -970,17 +1106,21 @@ bool isaPoolingNchwSumOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
+  // Per-dim dilations/strides (default 1) for the add-expr matchers below.
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
       bodyMatcherForSumPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaPoolingNhwcMaxOp(LinalgOp op) {
+bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::PoolingNhwcMaxOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -992,17 +1132,21 @@ bool isaPoolingNhwcMaxOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
+  // Per-dim dilations/strides (default 1) for the add-expr matchers below.
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
       bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaPoolingNhwcMinOp(LinalgOp op) {
+bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::PoolingNhwcMinOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -1014,17 +1158,21 @@ bool isaPoolingNhwcMinOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
+  // Per-dim dilations/strides (default 1) for the add-expr matchers below.
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
       bodyMatcherForMinSignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaPoolingNhwcSumOp(LinalgOp op) {
+bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::PoolingNhwcSumOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -1036,17 +1184,21 @@ bool isaPoolingNhwcSumOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
+  // Per-dim dilations/strides (default 1) for the add-expr matchers below.
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
       bodyMatcherForSumPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op) {
+bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -1058,17 +1210,21 @@ bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
+  
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
       bodyMatcherForMaxUnsignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaPoolingNhwcMinUnsignedOp(LinalgOp op) {
+bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::PoolingNhwcMinUnsignedOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -1080,17 +1236,21 @@ bool isaPoolingNhwcMinUnsignedOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
+  
+  SmallVector<int64_t> tempDilations(2,1);
+  SmallVector<int64_t> tempStrides(2,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
       bodyMatcherForMinUnsignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaPoolingNcwMaxOp(LinalgOp op) {
+bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::PoolingNcwMaxOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -1102,16 +1262,20 @@ bool isaPoolingNcwMaxOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
+ 
+  SmallVector<int64_t> tempDilations(1,1);
+  SmallVector<int64_t> tempStrides(1,1);
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
       bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaPoolingNcwSumOp(LinalgOp op) {
+bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::PoolingNcwSumOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -1123,16 +1287,20 @@ bool isaPoolingNcwSumOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
+ 
+  SmallVector<int64_t> tempDilations(1,1);
+  SmallVector<int64_t> tempStrides(1,1);
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
       bodyMatcherForSumPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaPoolingNwcMaxOp(LinalgOp op) {
+bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::PoolingNwcMaxOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -1144,16 +1312,20 @@ bool isaPoolingNwcMaxOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(1,1);
+  SmallVector<int64_t> tempStrides(1,1);
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
       bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaPoolingNwcMinOp(LinalgOp op) {
+bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::PoolingNwcMinOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -1165,16 +1337,20 @@ bool isaPoolingNwcMinOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(1,1);
+  SmallVector<int64_t> tempStrides(1,1);
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
       bodyMatcherForMinSignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaPoolingNwcSumOp(LinalgOp op) {
+bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::PoolingNwcSumOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -1186,16 +1362,20 @@ bool isaPoolingNwcSumOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(1,1);
+  SmallVector<int64_t> tempStrides(1,1);
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
       bodyMatcherForSumPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaPoolingNdhwcMaxOp(LinalgOp op) {
+bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::PoolingNdhwcMaxOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -1207,18 +1387,22 @@ bool isaPoolingNdhwcMaxOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
+  
+  SmallVector<int64_t> tempDilations(3,1);
+  SmallVector<int64_t> tempStrides(3,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
       bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaPoolingNdhwcMinOp(LinalgOp op) {
+bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::PoolingNdhwcMinOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -1230,18 +1414,22 @@ bool isaPoolingNdhwcMinOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
+  
+  SmallVector<int64_t> tempDilations(3,1);
+  SmallVector<int64_t> tempStrides(3,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
       bodyMatcherForMinSignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
-bool isaPoolingNdhwcSumOp(LinalgOp op) {
+bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
   if (isa<linalg::PoolingNdhwcSumOp>(op)) return true;
 
   if (!isaConvolutionOpInterface(op)) return false;
@@ -1253,15 +1441,19 @@ bool isaPoolingNdhwcSumOp(LinalgOp op) {
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
+  
+  SmallVector<int64_t> tempDilations(3,1);
+  SmallVector<int64_t> tempStrides(3,1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
+      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
       matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
       bodyMatcherForSumPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
 }
 
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,

>From 1a9417d3623074ba40ee95c7ef94c5c14b678d87 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 14 Oct 2025 06:19:23 -0500
Subject: [PATCH 16/18] Add lit test and clean up

---
 .../Dialect/Linalg/Transforms/Specialize.cpp  |   5 +-
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       |   9 +
 ...oundtrip-linalg-convolution-named-ops.mlir | 615 ++++++++++++++++++
 3 files changed, 627 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/roundtrip-linalg-convolution-named-ops.mlir

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 94dfbcc15d055..12eb17ef0a435 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,6 +237,7 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
 }
 
+/// Utility to specialize `genericOp` into a convolution op of type `ConvOpTy` with the given `dilations` and `strides`.
 template <typename ConvOpTy>
 static FailureOr<LinalgOp> specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
   SmallVector<Value> inputs = genericOp.getDpsInputs();
@@ -380,7 +381,7 @@ static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank9ConvIteratorTypes(Rewri
   return failure();
 }
 
-// Converts linalg.generic to named linalg.*conv* where possible.
+// Converts linalg.generic to named linalg.*conv/pooling* where possible. To speed up the search, the convolution ops are grouped by the rank of the iterator types array.
 static FailureOr<LinalgOp> inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
   SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
   unsigned totalIterators = iteratorTypes.size();
@@ -483,7 +484,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
     return specializeLinalgContractions(rewriter, genericOp);
   }
 
-  // Convolution - e.g. *conv*
+  // Convolution - e.g. *conv/pooling*
   if (isaConvolutionOpInterface(genericOp)) {
     return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
   }
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 548f43f83b0ed..0d4e8aa5e6382 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -319,6 +319,11 @@ static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_
   return false;
 }
 
+/// Given an array of AffineMaps `indexingMaps`, verify that:
+///   indexingMaps[0].getResult(iDim) ==
+///         indexingMaps[1].getResult(fDim) * <CST_1> +
+///         indexingMaps[n-1].getResult(oDim) * <CST_2>
+/// where CST_1 and CST_2 are constants (presumably returned via `dilation` and `stride`).
 static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim,
   int64_t& dilation, int64_t& stride) {
   unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
@@ -348,10 +353,13 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, un
   return false;
 }
 
+/// Given an array of AffineMaps `indexingMaps`, verify that:
+///   indexingMaps[aIndex].getResult(aDim) == indexingMaps[bIndex].getResult(bDim)
 static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned aDim, unsigned bIndex, unsigned bDim) {
   return getAffineMapDim(indexingMaps, aIndex, aDim) == getAffineMapDim(indexingMaps, bIndex, bDim);
 }
 
+/// Given an array of AffineMaps, verify that each map has the size given by the corresponding entry of `expectedSizes`.
 static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, ArrayRef<int64_t> expectedSizes) {
   if (indexingMaps.size() != expectedSizes.size()) return false;
 
@@ -362,6 +370,7 @@ static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, ArrayRef<int64_t>
   return true;
 }
 
+/// Utility to update `dilations` and `strides` by copying the corresponding data from `tempDilations` and `tempStrides`; returns true without copying if either output pointer is null.
 static bool updateConvDilationsAndStrides(SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides, ArrayRef<int64_t> tempDilations, ArrayRef<int64_t> tempStrides) {
   if (!(dilations && strides))
     return true;
diff --git a/mlir/test/Dialect/Linalg/roundtrip-linalg-convolution-named-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-linalg-convolution-named-ops.mlir
new file mode 100644
index 0000000000000..8cd57044e613f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-linalg-convolution-named-ops.mlir
@@ -0,0 +1,615 @@
+// The following tests exercise linalg convolution named ops that are lowered to linalg.generic and
+// then lifted back up to the named op.
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+func.func @conv_1d_nwc_wcf(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?xf32>) {
+  linalg.conv_1d_nwc_wcf {dilations = dense<3> : tensor<1xi64>,
+                                       strides = dense<2> : tensor<1xi64>}
+     ins (%input, %filter: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs (%output: memref<?x?x?xf32>)
+  return
+}
+//      CHECK: @conv_1d_nwc_wcf
+//      CHECK:   linalg.conv_1d_nwc_wcf
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_1d_ncw_fcw(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?xf32>) {
+  linalg.conv_1d_ncw_fcw {dilations = dense<3> : tensor<1xi64>,
+                                       strides = dense<2> : tensor<1xi64>}
+     ins (%input, %filter: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs (%output: memref<?x?x?xf32>)
+  return
+}
+//      CHECK: @conv_1d_ncw_fcw
+//      CHECK:   linalg.conv_1d_ncw_fcw
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_1d(%in : memref<?xf32>, %filter : memref<?xf32>, %out : memref<?xf32>) -> () {
+  linalg.conv_1d ins(%in, %filter : memref<?xf32>, memref<?xf32>)
+                outs(%out : memref<?xf32>)
+  return
+}
+//      CHECK: @conv_1d
+//      CHECK:   linalg.conv_1d
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_1d_ncw_cw(%input: memref<?x?x?xf32>, %filter: memref<?x?xf32>, %output: memref<?x?x?xf32>) {
+  linalg.depthwise_conv_1d_ncw_cw {dilations = dense<3> : tensor<1xi64>,
+                                       strides = dense<2> : tensor<1xi64>}
+     ins (%input, %filter: memref<?x?x?xf32>, memref<?x?xf32>)
+    outs (%output: memref<?x?x?xf32>)
+  return
+}
+//      CHECK: @depthwise_conv_1d_ncw_cw
+//      CHECK:   linalg.depthwise_conv_1d_ncw_cw
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_1d_nwc_wc(%input: memref<?x?x?xf32>, %filter: memref<?x?xf32>, %output: memref<?x?x?xf32>) {
+  linalg.depthwise_conv_1d_nwc_wc {dilations = dense<3> : tensor<1xi64>,
+                                       strides = dense<2> : tensor<1xi64>}
+     ins (%input, %filter: memref<?x?x?xf32>, memref<?x?xf32>)
+    outs (%output: memref<?x?x?xf32>)
+  return
+}
+//      CHECK: @depthwise_conv_1d_nwc_wc
+//      CHECK:   linalg.depthwise_conv_1d_nwc_wc
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_1d_nwc_wcm(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?x?xf32>) {
+  linalg.depthwise_conv_1d_nwc_wcm {dilations = dense<1> : tensor<1xi64>,
+                                       strides = dense<1> : tensor<1xi64>}
+     ins (%input, %filter: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs (%output: memref<?x?x?x?xf32>)
+  return
+}
+//      CHECK: @depthwise_conv_1d_nwc_wcm
+//      CHECK:   linalg.depthwise_conv_1d_nwc_wcm
+// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () {
+  linalg.conv_2d ins(%in, %filter : memref<?x?xf32>, memref<?x?xf32>)
+                outs(%out: memref<?x?xf32>)
+  return
+}
+//      CHECK: @conv_2d
+//      CHECK:   linalg.conv_2d
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_nchw_fchw(%arg0 : tensor<?x?x?x?xf32>,
+    %arg1 : tensor<?x?x?x?xf32>,  %arg2 : tensor<?x?x?x?xf32>) ->
+    (tensor<4x8x12x16xf32>, tensor<?x?x?x?xf32>) {
+  %0 = linalg.conv_2d_nchw_fchw {dilations = dense<[2,4]> : tensor<2xi64>, strides = dense<[3,5]> : tensor<2xi64>}
+  ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+      outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  %1 = tensor.cast %0 : tensor<?x?x?x?xf32> to tensor<4x8x12x16xf32>
+  return %1, %0 : tensor<4x8x12x16xf32>, tensor<?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nchw_fchw
+//      CHECK:   linalg.conv_2d_nchw_fchw
+// CHECK-SAME:      dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_nchw_fchw_q(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %inputzp: i32, %filterzp: i32, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter, %inputzp, %filterzp: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, i32, i32)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nchw_fchw_q
+//      CHECK:   linalg.conv_2d_nchw_fchw_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_ngchw_fgchw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.conv_2d_ngchw_fgchw {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_ngchw_fgchw
+//      CHECK:   linalg.conv_2d_ngchw_fgchw
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_ngchw_gfchw(%input: memref<?x?x?x?x?xi8>, %filter: memref<?x?x?x?x?xi8>, %output: memref<?x?x?x?x?xi32>) {
+  linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : tensor<2xi64>,
+                                       strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: memref<?x?x?x?x?xi8>, memref<?x?x?x?x?xi8>)
+    outs (%output: memref<?x?x?x?x?xi32>)
+  return
+}
+//      CHECK: @conv_2d_ngchw_gfchw
+//      CHECK:   linalg.conv_2d_ngchw_gfchw
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_ngchw_gfchw_q(%input: memref<?x?x?x?x?xi8>, %filter: memref<?x?x?x?x?xi8>, %inputzp: i32, %filterzp: i32, %output: memref<?x?x?x?x?xi32>) {
+  linalg.conv_2d_ngchw_gfchw_q {dilations = dense<1> : tensor<2xi64>,
+                                       strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter, %inputzp, %filterzp: memref<?x?x?x?x?xi8>, memref<?x?x?x?x?xi8>, i32, i32)
+    outs (%output: memref<?x?x?x?x?xi32>)
+  return
+}
+//      CHECK: @conv_2d_ngchw_gfchw_q
+//      CHECK:   linalg.conv_2d_ngchw_gfchw_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_nhwc_hwcf_q(%input: memref<?x?x?x?xf32>, %filter: memref<?x?x?x?xf32>, %inputzp: i32, %filterzp: i32, %output: memref<?x?x?x?xf32>) {
+  linalg.conv_2d_nhwc_hwcf_q {
+    dilations = dense<1> : tensor<2xi64>,
+    strides = dense<1> : tensor<2xi64>
+  } ins(%input, %filter, %inputzp, %filterzp : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, i32, i32) outs(%output : memref<?x?x?x?xf32>)
+  return
+}
+//      CHECK: @conv_2d_nhwc_hwcf_q
+//      CHECK:   linalg.conv_2d_nhwc_hwcf_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_nhwgc_gfhwc_q(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %inputzp: i32, %filterzp: i32, %output: memref<?x?x?x?x?xf32>) {
+  linalg.conv_2d_nhwgc_gfhwc_q {
+    dilations = dense<1> : tensor<2xi64>,
+    strides = dense<1> : tensor<2xi64>
+  } ins(%input, %filter, %inputzp, %filterzp : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>, i32, i32) outs(%output : memref<?x?x?x?x?xf32>)
+  return
+}
+//      CHECK: @conv_2d_nhwgc_gfhwc_q
+//      CHECK:   linalg.conv_2d_nhwgc_gfhwc_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwc_q(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?x?xi8>, %inputzp: i32, %filterzp: i32, %output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>{
+  %res = linalg.depthwise_conv_2d_nhwc_hwc_q {
+    dilations = dense<1> : tensor<2xi64>,
+    strides = dense<1> : tensor<2xi64>
+  } ins(%input, %filter, %inputzp, %filterzp : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%output : tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+  return %res : tensor<?x?x?x?xi32>
+}
+//      CHECK: @depthwise_conv_2d_nhwc_hwc_q
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwc_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_nhwc_fhwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                 strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nhwc_fhwc
+//      CHECK:   linalg.conv_2d_nhwc_fhwc
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_nhwc_fhwc_q(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %inputzp: i32, %filterzp: i32, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter, %inputzp, %filterzp: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, i32, i32)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nhwc_fhwc_q
+//      CHECK:   linalg.conv_2d_nhwc_fhwc_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_nhwc_hwcf(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nhwc_hwcf
+//      CHECK:   linalg.conv_2d_nhwc_hwcf
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_nhwgc_gfhwc(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
+  linalg.conv_2d_nhwgc_gfhwc {dilations = dense<1> : tensor<2xi64>,
+                                         strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+    outs (%output: memref<?x?x?x?x?xf32>)
+  return
+}
+//      CHECK: @conv_2d_nhwgc_gfhwc
+//      CHECK:   linalg.conv_2d_nhwgc_gfhwc
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_2d_nchw_chw(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_2d_nchw_chw
+//      CHECK:   linalg.depthwise_conv_2d_nchw_chw
+// CHECK-SAME:      dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[4, 5]> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>,
+                                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_2d_nhwc_hwc
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwc
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwcm(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>,
+                                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_2d_nhwc_hwcm
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwcm
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwcm_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x?xi8>, %arg2: tensor<?x?x?x?x?xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x?xi32> {
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?x?xi32>
+}
+//      CHECK: @depthwise_conv_2d_nhwc_hwcm_q
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwcm_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_3d(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () {
+  linalg.conv_3d ins(%in, %filter : memref<?x?x?xf32>, memref<?x?x?xf32>)
+                outs(%out : memref<?x?x?xf32>)
+  return
+}
+//      CHECK: @conv_3d
+//      CHECK:   linalg.conv_3d
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_3d_ncdhw_fcdhw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.conv_3d_ncdhw_fcdhw {dilations = dense<1> : tensor<3xi64>,
+                                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @conv_3d_ncdhw_fcdhw
+//      CHECK:   linalg.conv_3d_ncdhw_fcdhw
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_3d_ndhwc_dhwcf(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.conv_3d_ndhwc_dhwcf {dilations = dense<1> : tensor<3xi64>,
+                                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @conv_3d_ndhwc_dhwcf
+//      CHECK:   linalg.conv_3d_ndhwc_dhwcf
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_3d_ndhwc_dhwcf_q(%input: tensor<?x?x?x?x?xi8>, %filter: tensor<?x?x?x?x?xi8>, %inputzp: i32, %filterzp: i32, %init: tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32> {
+  %0 = linalg.conv_3d_ndhwc_dhwcf_q {dilations = dense<1> : tensor<3xi64>,
+                                                strides = dense<1> : tensor<3xi64>}
+    ins(%input, %filter, %inputzp, %filterzp : tensor<?x?x?x?x?xi8>, tensor<?x?x?x?x?xi8>, i32, i32)
+    outs (%init: tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?x?xi32>
+}
+//      CHECK: @conv_3d_ndhwc_dhwcf_q
+//      CHECK:   linalg.conv_3d_ndhwc_dhwcf_q
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_3d_ncdhw_cdhw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_3d_ncdhw_cdhw {dilations = dense<1> : tensor<3xi64>,
+                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_3d_ncdhw_cdhw
+//      CHECK:   linalg.depthwise_conv_3d_ncdhw_cdhw
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_3d_ndhwc_dhwc(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_3d_ndhwc_dhwc {dilations = dense<1> : tensor<3xi64>,
+                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_3d_ndhwc_dhwc
+//      CHECK:   linalg.depthwise_conv_3d_ndhwc_dhwc
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm {dilations = dense<1> : tensor<3xi64>,
+                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_3d_ndhwc_dhwcm
+//      CHECK:   linalg.depthwise_conv_3d_ndhwc_dhwcm
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nchw_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @pooling_nchw_max
+//      CHECK:   linalg.pooling_nchw_max
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nchw_sum(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nchw_sum {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @pooling_nchw_sum
+//      CHECK:   linalg.pooling_nchw_sum
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_ncw_max(%input: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>, %filter: tensor<?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.pooling_ncw_max {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+    ins(%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+    outs(%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+//      CHECK: @pooling_ncw_max
+//      CHECK:   linalg.pooling_ncw_max
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_ncw_sum(%input: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>, %filter: tensor<?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.pooling_ncw_sum {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+    ins(%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+    outs(%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+//      CHECK: @pooling_ncw_sum
+//      CHECK:   linalg.pooling_ncw_sum
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_ndhwc_max(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.pooling_ndhwc_max {dilations = dense<1> : tensor<3xi64>,
+                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @pooling_ndhwc_max
+//      CHECK:   linalg.pooling_ndhwc_max
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_ndhwc_min(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.pooling_ndhwc_min {dilations = dense<1> : tensor<3xi64>,
+                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @pooling_ndhwc_min
+//      CHECK:   linalg.pooling_ndhwc_min
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_ndhwc_sum(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.pooling_ndhwc_sum {dilations = dense<1> : tensor<3xi64>,
+                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @pooling_ndhwc_sum
+//      CHECK:   linalg.pooling_ndhwc_sum
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @pooling_nhwc_max
+//      CHECK:   linalg.pooling_nhwc_max
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_min(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @pooling_nhwc_min
+//      CHECK:   linalg.pooling_nhwc_min
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_sum(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @pooling_nhwc_sum
+//      CHECK:   linalg.pooling_nhwc_sum
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_max_unsigned(%input: tensor<?x?x?x?xi32>, %filter: tensor<?x?xi32>, %init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+  %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xi32>, tensor<?x?xi32>)
+    outs (%init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?xi32>
+}
+//      CHECK: @pooling_nhwc_max_unsigned
+//      CHECK:   linalg.pooling_nhwc_max_unsigned
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_min_unsigned(%input: tensor<?x?x?x?xi32>, %filter: tensor<?x?xi32>, %init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+  %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xi32>, tensor<?x?xi32>)
+    outs (%init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?xi32>
+}
+//      CHECK: @pooling_nhwc_min_unsigned
+//      CHECK:   linalg.pooling_nhwc_min_unsigned
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nwc_max(%input: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>, %filter: tensor<?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.pooling_nwc_max {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+    outs(%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+//      CHECK: @pooling_nwc_max
+//      CHECK:   linalg.pooling_nwc_max
+// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nwc_min(%input: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>, %filter: tensor<?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.pooling_nwc_min {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+    ins(%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+    outs(%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+//      CHECK: @pooling_nwc_min
+//      CHECK:   linalg.pooling_nwc_min
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nwc_sum(%input: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>, %filter: tensor<?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.pooling_nwc_sum {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+    ins(%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+    outs(%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+//      CHECK: @pooling_nwc_sum
+//      CHECK:   linalg.pooling_nwc_sum
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+//  CHECK-NOT:   linalg.generic

>From bba3921a36b934ee96f4bc4794e47dd329e9995b Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 14 Oct 2025 07:22:14 -0500
Subject: [PATCH 17/18] Format code

---
 .../include/mlir/Dialect/Linalg/Utils/Utils.h |  153 +-
 .../Dialect/Linalg/Transforms/Specialize.cpp  |  221 +-
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 2066 ++++++++++-------
 3 files changed, 1527 insertions(+), 913 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 44ebc101d7c37..0f39098ca9946 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -115,50 +115,119 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
 //===----------------------------------------------------------------------===//
 
 bool isaConv1DOp(LinalgOp op);
-bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+                       SmallVector<int64_t> *strides = nullptr);
+bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+                       SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv1DNcwCwOp(LinalgOp op,
+                               SmallVector<int64_t> *dilations = nullptr,
+                               SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
+                               SmallVector<int64_t> *dilations = nullptr,
+                               SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op,
+                                SmallVector<int64_t> *dilations = nullptr,
+                                SmallVector<int64_t> *strides = nullptr);
 bool isaConv2DOp(LinalgOp op);
-bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+                         SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+                         SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+                         SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNhwcFhwcQOp(LinalgOp op,
+                          SmallVector<int64_t> *dilations = nullptr,
+                          SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNchwFchwQOp(LinalgOp op,
+                          SmallVector<int64_t> *dilations = nullptr,
+                          SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNgchwFgchwOp(LinalgOp op,
+                           SmallVector<int64_t> *dilations = nullptr,
+                           SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNgchwGfchwOp(LinalgOp op,
+                           SmallVector<int64_t> *dilations = nullptr,
+                           SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNhwcHwcfQOp(LinalgOp op,
+                          SmallVector<int64_t> *dilations = nullptr,
+                          SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNhwgcGfhwcQOp(LinalgOp op,
+                            SmallVector<int64_t> *dilations = nullptr,
+                            SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNgchwGfchwQOp(LinalgOp op,
+                            SmallVector<int64_t> *dilations = nullptr,
+                            SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNhwgcGfhwcOp(LinalgOp op,
+                           SmallVector<int64_t> *dilations = nullptr,
+                           SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
+                                 SmallVector<int64_t> *dilations = nullptr,
+                                 SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op,
+                                 SmallVector<int64_t> *dilations = nullptr,
+                                 SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op,
+                                  SmallVector<int64_t> *dilations = nullptr,
+                                  SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op,
+                                  SmallVector<int64_t> *dilations = nullptr,
+                                  SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op,
+                                   SmallVector<int64_t> *dilations = nullptr,
+                                   SmallVector<int64_t> *strides = nullptr);
 bool isaConv3DOp(LinalgOp op);
-bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv3DNcdhwFcdhwOp(LinalgOp op,
+                           SmallVector<int64_t> *dilations = nullptr,
+                           SmallVector<int64_t> *strides = nullptr);
+bool isaConv3DNdhwcDhwcfOp(LinalgOp op,
+                           SmallVector<int64_t> *dilations = nullptr,
+                           SmallVector<int64_t> *strides = nullptr);
+bool isaConv3DNdhwcDhwcfQOp(LinalgOp op,
+                            SmallVector<int64_t> *dilations = nullptr,
+                            SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
+                                    SmallVector<int64_t> *dilations = nullptr,
+                                    SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op,
+                                   SmallVector<int64_t> *dilations = nullptr,
+                                   SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op,
+                                   SmallVector<int64_t> *dilations = nullptr,
+                                   SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+                         SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+                         SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+                         SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+                         SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+                         SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op,
+                                 SmallVector<int64_t> *dilations = nullptr,
+                                 SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNhwcMinUnsignedOp(LinalgOp op,
+                                 SmallVector<int64_t> *dilations = nullptr,
+                                 SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+                        SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+                        SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+                        SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+                        SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+                        SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNdhwcMaxOp(LinalgOp op,
+                          SmallVector<int64_t> *dilations = nullptr,
+                          SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNdhwcMinOp(LinalgOp op,
+                          SmallVector<int64_t> *dilations = nullptr,
+                          SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNdhwcSumOp(LinalgOp op,
+                          SmallVector<int64_t> *dilations = nullptr,
+                          SmallVector<int64_t> *strides = nullptr);
 
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 12eb17ef0a435..e08705b90e7b0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,9 +237,12 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
 }
 
-/// Utility to create a `genericOp` with a convolution op of type `ConvOpTy` with `dilations` and `strides`.
+/// Utility to replace `genericOp` with a convolution op of type `ConvOpTy`
+/// with `dilations` and `strides`.
 template <typename ConvOpTy>
-static FailureOr<LinalgOp> specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
+static FailureOr<LinalgOp>
+specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
+                   ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
   SmallVector<Value> inputs = genericOp.getDpsInputs();
   ValueRange outputs = genericOp.getDpsInits();
   SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
@@ -247,159 +250,227 @@ static FailureOr<LinalgOp> specializeToConvOp(RewriterBase &rewriter, GenericOp
                                       ? TypeRange(ValueRange(outputs))
                                       : TypeRange{};
   LinalgOp namedOp;
-  if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> || std::is_same_v<ConvOpTy, linalg::Conv2DOp> || std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
-    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes, inputs, outputs);
+  if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
+                                                    inputs, outputs);
   } else {
     Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
     Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
-    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+        genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
   }
   return namedOp;
 }
 
-static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
   SmallVector<int64_t> dilations, strides;
-  if (isaConv1DOp(genericOp)) return specializeToConvOp<linalg::Conv1DOp>(rewriter, genericOp, dilations, strides);
+  if (isaConv1DOp(genericOp))
+    return specializeToConvOp<linalg::Conv1DOp>(rewriter, genericOp, dilations,
+                                                strides);
   return failure();
 }
 
-static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
   SmallVector<int64_t> dilations, strides;
   if (isaDepthwiseConv1DNcwCwOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::DepthwiseConv1DNcwCwOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::DepthwiseConv1DNcwCwOp>(
+        rewriter, genericOp, dilations, strides);
   if (isaDepthwiseConv1DNwcWcOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
+        rewriter, genericOp, dilations, strides);
   if (isaConv2DOp(genericOp))
-    return specializeToConvOp<linalg::Conv2DOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::Conv2DOp>(rewriter, genericOp, dilations,
+                                                strides);
   if (isaPoolingNcwMaxOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::PoolingNcwMaxOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::PoolingNcwMaxOp>(rewriter, genericOp,
+                                                       dilations, strides);
   if (isaPoolingNcwSumOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::PoolingNcwSumOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::PoolingNcwSumOp>(rewriter, genericOp,
+                                                       dilations, strides);
   if (isaPoolingNwcMaxOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::PoolingNwcMaxOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::PoolingNwcMaxOp>(rewriter, genericOp,
+                                                       dilations, strides);
   if (isaPoolingNwcMinOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::PoolingNwcMinOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::PoolingNwcMinOp>(rewriter, genericOp,
+                                                       dilations, strides);
   if (isaPoolingNwcSumOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::PoolingNwcSumOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::PoolingNwcSumOp>(rewriter, genericOp,
+                                                       dilations, strides);
   return failure();
 }
 
-static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
   SmallVector<int64_t> dilations, strides;
   if (isaDepthwiseConv1DNwcWcmOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::DepthwiseConv1DNwcWcmOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::DepthwiseConv1DNwcWcmOp>(
+        rewriter, genericOp, dilations, strides);
   if (isaConv1DNwcWcfOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::Conv1DNwcWcfOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::Conv1DNwcWcfOp>(rewriter, genericOp,
+                                                      dilations, strides);
   if (isaConv1DNcwFcwOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::Conv1DNcwFcwOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::Conv1DNcwFcwOp>(rewriter, genericOp,
+                                                      dilations, strides);
   return failure();
 }
 
-static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
   SmallVector<int64_t> dilations, strides;
   if (isaDepthwiseConv2DNchwChwOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
+        rewriter, genericOp, dilations, strides);
   if (isaDepthwiseConv2DNhwcHwcOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcOp>(
+        rewriter, genericOp, dilations, strides);
   if (isaDepthwiseConv2DNhwcHwcQOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcQOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcQOp>(
+        rewriter, genericOp, dilations, strides);
   if (isaConv3DOp(genericOp))
-    return specializeToConvOp<linalg::Conv3DOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::Conv3DOp>(rewriter, genericOp, dilations,
+                                                strides);
   if (isaPoolingNchwMaxOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::PoolingNchwMaxOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::PoolingNchwMaxOp>(rewriter, genericOp,
+                                                        dilations, strides);
   if (isaPoolingNchwSumOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::PoolingNchwSumOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::PoolingNchwSumOp>(rewriter, genericOp,
+                                                        dilations, strides);
   if (isaPoolingNhwcMaxOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
+                                                        dilations, strides);
   if (isaPoolingNhwcMinOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp,
+                                                        dilations, strides);
   if (isaPoolingNhwcSumOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp,
+                                                        dilations, strides);
   if (isaPoolingNhwcMaxUnsignedOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(
+        rewriter, genericOp, dilations, strides);
   if (isaPoolingNhwcMinUnsignedOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(
+        rewriter, genericOp, dilations, strides);
   return failure();
 }
 
-static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
   SmallVector<int64_t> dilations, strides;
   if (isaConv2DNhwcFhwcOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::Conv2DNhwcFhwcOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::Conv2DNhwcFhwcOp>(rewriter, genericOp,
+                                                        dilations, strides);
   if (isaConv2DNhwcHwcfOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::Conv2DNhwcHwcfOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::Conv2DNhwcHwcfOp>(rewriter, genericOp,
+                                                        dilations, strides);
   if (isaConv2DNchwFchwOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::Conv2DNchwFchwOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::Conv2DNchwFchwOp>(rewriter, genericOp,
+                                                        dilations, strides);
   if (isaConv2DNhwcFhwcQOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::Conv2DNhwcFhwcQOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::Conv2DNhwcFhwcQOp>(rewriter, genericOp,
+                                                         dilations, strides);
   if (isaConv2DNchwFchwQOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::Conv2DNchwFchwQOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::Conv2DNchwFchwQOp>(rewriter, genericOp,
+                                                         dilations, strides);
   if (isaConv2DNhwcHwcfQOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::Conv2DNhwcHwcfQOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::Conv2DNhwcHwcfQOp>(rewriter, genericOp,
+                                                         dilations, strides);
   if (isaDepthwiseConv2DNhwcHwcmOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcmOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcmOp>(
+        rewriter, genericOp, dilations, strides);
   if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcmQOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcmQOp>(
+        rewriter, genericOp, dilations, strides);
   return failure();
 }
 
-static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
   SmallVector<int64_t> dilations, strides;
   if (isaConv2DNgchwFgchwOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::Conv2DNgchwFgchwOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::Conv2DNgchwFgchwOp>(rewriter, genericOp,
+                                                          dilations, strides);
   if (isaConv2DNgchwGfchwOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::Conv2DNgchwGfchwOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::Conv2DNgchwGfchwOp>(rewriter, genericOp,
+                                                          dilations, strides);
   if (isaConv2DNgchwGfchwQOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::Conv2DNgchwGfchwQOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::Conv2DNgchwGfchwQOp>(rewriter, genericOp,
+                                                           dilations, strides);
   if (isaConv2DNhwgcGfhwcOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::Conv2DNhwgcGfhwcOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::Conv2DNhwgcGfhwcOp>(rewriter, genericOp,
+                                                          dilations, strides);
   if (isaConv2DNhwgcGfhwcQOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::Conv2DNhwgcGfhwcQOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::Conv2DNhwgcGfhwcQOp>(rewriter, genericOp,
+                                                           dilations, strides);
   if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::DepthwiseConv3DNcdhwCdhwOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::DepthwiseConv3DNcdhwCdhwOp>(
+        rewriter, genericOp, dilations, strides);
   if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(
+        rewriter, genericOp, dilations, strides);
   if (isaPoolingNdhwcMaxOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::PoolingNdhwcMaxOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::PoolingNdhwcMaxOp>(rewriter, genericOp,
+                                                         dilations, strides);
   if (isaPoolingNdhwcMinOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::PoolingNdhwcMinOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::PoolingNdhwcMinOp>(rewriter, genericOp,
+                                                         dilations, strides);
   if (isaPoolingNdhwcSumOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::PoolingNdhwcSumOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::PoolingNdhwcSumOp>(rewriter, genericOp,
+                                                         dilations, strides);
   return failure();
 }
 
-static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
   SmallVector<int64_t> dilations, strides;
   if (isaConv3DNcdhwFcdhwOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::Conv3DNcdhwFcdhwOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::Conv3DNcdhwFcdhwOp>(rewriter, genericOp,
+                                                          dilations, strides);
   if (isaConv3DNdhwcDhwcfOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::Conv3DNdhwcDhwcfOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::Conv3DNdhwcDhwcfOp>(rewriter, genericOp,
+                                                          dilations, strides);
   if (isaConv3DNdhwcDhwcfQOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::Conv3DNdhwcDhwcfQOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::Conv3DNdhwcDhwcfQOp>(rewriter, genericOp,
+                                                           dilations, strides);
   if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(rewriter, genericOp, dilations, strides);
+    return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+        rewriter, genericOp, dilations, strides);
   return failure();
 }
 
-// Converts linalg.generic to named linalg.*conv/pooling* where possible. To improve the search speed, the convolution ops have been segregated based on the rank of iterator types array.
-static FailureOr<LinalgOp> inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
-  SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
+/// Converts linalg.generic to named linalg.*conv/pooling* where possible. To
+/// narrow the search, candidate ops are bucketed by the number of iterator
+/// types; loop counts without a bucket (e.g. 3) return failure().
+static FailureOr<LinalgOp>
+inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
+  SmallVector<utils::IteratorType> iteratorTypes =
+      genericOp.getIteratorTypesArray();
   unsigned totalIterators = iteratorTypes.size();
-  switch(totalIterators) {
-    case 2:
-      return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
-    case 4:
-      return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
-    case 5:
-      return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
-    case 6:
-      return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
-    case 7:
-      return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
-    case 8:
-      return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
-    case 9:
-      return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
+  switch (totalIterators) {
+  case 2:
+    return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
+  case 4:
+    return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
+  case 5:
+    return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
+  case 6:
+    return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
+  case 7:
+    return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
+  case 8:
+    return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
+  case 9:
+    return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
   }
   return failure();
 }
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 0d4e8aa5e6382..8ea5e7a10e17e 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -249,28 +249,34 @@ template <typename... OpTypes>
 static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
   Operation *defOp = yieldVal.getDefiningOp();
   // if (!defOp) return false;
-  if (!(isa_and_present<OpTypes>(defOp) || ...)) return false;
+  if (!(isa_and_present<OpTypes>(defOp) || ...))
+    return false;
 
-  BlockArgument lhsArg =  dyn_cast<BlockArgument>(defOp->getOperand(0));
-  BlockArgument rhsArg =  dyn_cast<BlockArgument>(defOp->getOperand(1));
-  if (!lhsArg || !rhsArg) return false;
+  BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
+  BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand(1));
+  if (!lhsArg || !rhsArg)
+    return false;
   return true;
 }
 
 static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
-  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal, body);
+  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal,
+                                                                  body);
 }
 
 static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
-  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal, body);
+  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
+                                                                  body);
 }
 
 static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
-  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal, body);
+  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal,
+                                                                  body);
 }
 
 static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
-  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal, body);
+  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
+                                                                  body);
 }
 
 static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
@@ -288,7 +294,8 @@ static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
 // Check if `expr` is either:
 // - a dimension expr alone (implying *1), or
 // - a multiplication of dimension expr by constant.
-static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_t &constantValue) {
+static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim,
+                                        int64_t &constantValue) {
   if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
     dim = dExpr;
     constantValue = 1;
@@ -320,12 +327,13 @@ static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_
 }
 
 /// Given an array of AffineMaps `indexingMaps` verify the following :-
-///   indexingMaps[0].getResult(iDim) == 
+///   indexingMaps[0].getResult(iDim) ==
 ///         indexingMaps[1].getResult(fDim) * <CST_1> +
 ///         indexingMaps[n-1].getResult(oDim) * <CST_2>
 ///  where, CST_1 and CST_2 can be any constant.
-static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim,
-  int64_t& dilation, int64_t& stride) {
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
+                                       unsigned fDim, unsigned oDim,
+                                       int64_t &dilation, int64_t &stride) {
   unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
   AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
   auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
@@ -354,24 +362,37 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, un
 }
 
 /// Given an array of AffineMaps `indexingMaps` verify the following :-
-///   indexingMaps[aIndex].getResult(aDim) == indexingMaps[bIndex].getResult(bDim)
-static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned aDim, unsigned bIndex, unsigned bDim) {
-  return getAffineMapDim(indexingMaps, aIndex, aDim) == getAffineMapDim(indexingMaps, bIndex, bDim);
-}
-
-/// Give an array of AffineMaps, verify each map to be of the corresponding `expectedSize`.
-static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, ArrayRef<int64_t> expectedSizes) {
-  if (indexingMaps.size() != expectedSizes.size()) return false;
+///   indexingMaps[aIndex].getResult(aDim) ==
+///   indexingMaps[bIndex].getResult(bDim)
+static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex,
+                                    unsigned aDim, unsigned bIndex,
+                                    unsigned bDim) {
+  return getAffineMapDim(indexingMaps, aIndex, aDim) ==
+         getAffineMapDim(indexingMaps, bIndex, bDim);
+}
+
+/// Given an array of AffineMaps, verify there is one map per entry of
+/// `expectedSizes` and that each map has that many results.
+static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
+                                       ArrayRef<int64_t> expectedSizes) {
+  if (indexingMaps.size() != expectedSizes.size())
+    return false;
 
-  for (auto [indexingMap, expectedSize] : llvm::zip_equal(indexingMaps, expectedSizes)) {
+  for (auto [indexingMap, expectedSize] :
+       llvm::zip_equal(indexingMaps, expectedSizes)) {
     auto affineMap = cast<AffineMapAttr>(indexingMap).getValue();
-    if (affineMap.getNumResults() != expectedSize) return false;
+    if (affineMap.getNumResults() != expectedSize)
+      return false;
   }
   return true;
 }
 
-/// Utility to update `dilations` and `strides` by copy the corresponding data from `tempDilations` and `tempStrides`.
-static bool updateConvDilationsAndStrides(SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides, ArrayRef<int64_t> tempDilations, ArrayRef<int64_t> tempStrides) {
+/// Utility to update `dilations` and `strides` by copy the corresponding data
+/// from `tempDilations` and `tempStrides`.
+static bool updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
+                                          SmallVector<int64_t> *strides,
+                                          ArrayRef<int64_t> tempDilations,
+                                          ArrayRef<int64_t> tempStrides) {
   if (!(dilations && strides))
     return true;
   for (auto [dilation, stride] : llvm::zip(tempDilations, tempStrides)) {
@@ -382,1087 +403,1540 @@ static bool updateConvDilationsAndStrides(SmallVector<int64_t>* dilations, Small
 }
 
 bool isaConv1DOp(LinalgOp op) {
-  if (isa<linalg::Conv1DOp>(op)) return true;
+  if (isa<linalg::Conv1DOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {1,1,1})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {1, 1, 1}))
+    return false;
+
   // #map = affine_map<(d0, d1) -> (d0 + d1)>
   // #map1 = affine_map<(d0, d1) -> (d1)>
   // #map2 = affine_map<(d0, d1) -> (d0)>
-  SmallVector<int64_t> tempDilations(1,1);
-  SmallVector<int64_t> tempStrides(1,1);
-  return matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0, tempDilations[0], tempStrides[0]);
+  // linalg.conv_1d has no stride/dilation attributes: only unit values match.
+  int64_t dilation = 1, stride = 1;
+  return matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
+                                    /*oDim=*/0, dilation, stride) &&
+         dilation == 1 && stride == 1;
 }
 
-bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::Conv1DNwcWcfOp>(op)) return true;
+bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                       SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv1DNwcWcfOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3, 3, 3}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(1,1);
-  SmallVector<int64_t> tempStrides(1,1);
+
+  SmallVector<int64_t> tempDilations(1, 1);
+  SmallVector<int64_t> tempStrides(1, 1);
   // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::Conv1DNcwFcwOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                       SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv1DNcwFcwOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3, 3, 3}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(1,1);
-  SmallVector<int64_t> tempStrides(1,1);
+
+  SmallVector<int64_t> tempDilations(1, 1);
+  SmallVector<int64_t> tempStrides(1, 1);
   // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::DepthwiseConv1DNcwCwOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                               SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {3,2,3})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(1,1);
-  SmallVector<int64_t> tempStrides(1,1);
+
+  SmallVector<int64_t> tempDilations(1, 1);
+  SmallVector<int64_t> tempStrides(1, 1);
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[0], tempStrides[0]));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[0],
+                                  tempStrides[0]));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
 }
 
 // -------------------
-bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::DepthwiseConv1DNwcWcOp>(op)) return true;
+bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                               SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {3,2,3})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(1,1);
-  SmallVector<int64_t> tempStrides(1,1);
+
+  SmallVector<int64_t> tempDilations(1, 1);
+  SmallVector<int64_t> tempStrides(1, 1);
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,4})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3, 3, 4}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(1,1);
-  SmallVector<int64_t> tempStrides(1,1);
+
+  SmallVector<int64_t> tempDilations(1, 1);
+  SmallVector<int64_t> tempStrides(1, 1);
   // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
 }
 
 bool isaConv2DOp(LinalgOp op) {
-  if (isa<linalg::Conv2DOp>(op)) return true;
+  if (isa<linalg::Conv2DOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {2,2,2})) return false;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
+  if (!verifyConvIndexingMapSizes(indexingMaps, {2, 2, 2}))
+    return false;
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
   // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
-  return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[1], tempStrides[1]));
+  return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
+                                     /*oDim=*/0, tempDilations[0],
+                                     tempStrides[0]) &&
+          matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+                                     /*oDim=*/1, tempDilations[1],
+                                     tempStrides[1]));
 }
 
-bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::Conv2DNhwcFhwcOp>(op)) return true;
+bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                         SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwcFhwcOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 4}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
 
-bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::Conv2DNhwcHwcfOp>(op)) return true;
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // #map = affine_map<(d0, ..., d6) -> (d0, d1 + d4, d2 + d5, d6)>
+  // #map1 = affine_map<(d0, ..., d6) -> (d3, d4, d5, d6)>
+  // #map2 = affine_map<(d0, ..., d6) -> (d0, d1, d2, d3)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                         SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwcHwcfOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 4}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
 
-bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::Conv2DNchwFchwOp>(op)) return true;
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // #map = affine_map<(d0, ..., d6) -> (d0, d1 + d4, d2 + d5, d6)>
+  // #map1 = affine_map<(d0, ..., d6) -> (d4, d5, d6, d3)>
+  // #map2 = affine_map<(d0, ..., d6) -> (d0, d1, d2, d3)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                         SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNchwFchwOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 4}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
 
-bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::Conv2DNhwcFhwcQOp>(op)) return true;
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // #map = affine_map<(d0, ..., d6) -> (d0, d4, d2 + d5, d3 + d6)>
+  // #map1 = affine_map<(d0, ..., d6) -> (d1, d4, d5, d6)>
+  // #map2 = affine_map<(d0, ..., d6) -> (d0, d1, d2, d3)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                  /*oDim=*/3, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                          SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwcFhwcQOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 0, 0, 4}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 4;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // #map = affine_map<(d0, ..., d6) -> (d0, d1 + d4, d2 + d5, d6)>
+  // #map1 = affine_map<(d0, ..., d6) -> (d3, d4, d5, d6)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::Conv2DNchwFchwQOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                          SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNchwFchwQOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 0, 0, 4}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 4;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // #map = affine_map<(d0, ..., d6) -> (d0, d4, d2 + d5, d3 + d6)>
+  // #map1 = affine_map<(d0, ..., d6) -> (d1, d4, d5, d6)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::Conv2DNgchwFgchwOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                  /*oDim=*/3, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                           SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNgchwFgchwOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::Conv2DNgchwGfchwOp>(op)) return true;
-
-  if (!isaConvolutionOpInterface(op)) return false;
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // Expected indexing maps (input, filter, output):
+  // #map = affine_map<(d0, ..., d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+  // #map1 = affine_map<(d0, ..., d7) -> (d2, d1, d5, d6, d7)>
+  // #map2 = affine_map<(d0, ..., d7) -> (d0, d1, d2, d3, d4)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                  /*oDim=*/3, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
+                                  /*oDim=*/4, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                           SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNgchwGfchwOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::Conv2DNhwcHwcfQOp>(op)) return true;
-
-  if (!isaConvolutionOpInterface(op)) return false;
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // Expected indexing maps (input, filter, output):
+  // #map = affine_map<(d0, ..., d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+  // #map1 = affine_map<(d0, ..., d7) -> (d1, d2, d5, d6, d7)>
+  // #map2 = affine_map<(d0, ..., d7) -> (d0, d1, d2, d3, d4)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                  /*oDim=*/3, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
+                                  /*oDim=*/4, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                          SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwcHwcfQOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 0, 0, 4}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 4;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // #map = affine_map<(d0, ..., d6) -> (d0, d1 + d4, d2 + d5, d6)>
+  // #map1 = affine_map<(d0, ..., d6) -> (d4, d5, d6, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+/// Returns true if `op` is a linalg::Conv2DNhwgcGfhwcQOp, or a
+/// LinalgOp that passes isaConvolutionOpInterface and whose indexing maps
+/// match that op's layout. In the latter case tempDilations/tempStrides
+/// (initialized to 1 and passed to matchConvDimAddExprPattern) are forwarded
+/// to updateConvDilationsAndStrides, whose result also gates the return value.
+/// NOTE(review): the named-op early return leaves `dilations`/`strides`
+/// untouched; confirm callers do not rely on them being set in that case.
+bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                            SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 0, 0, 5}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 4;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
-  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
 
-bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::Conv2DNgchwGfchwQOp>(op)) return true;
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // Indexing maps of operands 0 (input), 1 (filter), 2-3, and 4 (output):
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) ->
+  //   (d0, d1 + d5, d2 + d6, d3, d7)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) ->
+  //   (d3, d4, d5, d6, d7)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) ->
+  //   (d0, d1, d2, d3, d4)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+/// Returns true if `op` is a linalg::Conv2DNgchwGfchwQOp, or a
+/// LinalgOp that passes isaConvolutionOpInterface and whose indexing maps
+/// match that op's layout. In the latter case tempDilations/tempStrides
+/// (initialized to 1 and passed to matchConvDimAddExprPattern) are forwarded
+/// to updateConvDilationsAndStrides, whose result also gates the return value.
+/// NOTE(review): the named-op early return leaves `dilations`/`strides`
+/// untouched; confirm callers do not rely on them being set in that case.
+bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                            SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNgchwGfchwQOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 0, 0, 5}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 4;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
-  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::Conv2DNhwgcGfhwcOp>(op)) return true;
-
-  if (!isaConvolutionOpInterface(op)) return false;
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // Indexing maps of operands 0 (input), 1 (filter), 2-3, and 4 (output):
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) ->
+  //   (d0, d1, d5, d3 + d6, d4 + d7)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) ->
+  //   (d1, d2, d5, d6, d7)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) ->
+  //   (d0, d1, d2, d3, d4)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                  /*oDim=*/3, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
+                                  /*oDim=*/4, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+/// Returns true if `op` is a linalg::Conv2DNhwgcGfhwcOp, or a
+/// LinalgOp that passes isaConvolutionOpInterface and whose indexing maps
+/// match that op's layout. In the latter case tempDilations/tempStrides
+/// (initialized to 1 and passed to matchConvDimAddExprPattern) are forwarded
+/// to updateConvDilationsAndStrides, whose result also gates the return value.
+/// NOTE(review): the named-op early return leaves `dilations`/`strides`
+/// untouched; confirm callers do not rely on them being set in that case.
+bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                           SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwgcGfhwcOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::DepthwiseConv2DNchwChwOp>(op)) return true;
-
-  if (!isaConvolutionOpInterface(op)) return false;
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // Indexing maps of the input, filter and output operands:
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) ->
+  //   (d0, d1 + d5, d2 + d6, d3, d7)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) ->
+  //   (d3, d4, d5, d6, d7)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) ->
+  //   (d0, d1, d2, d3, d4)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+/// Returns true if `op` is a linalg::DepthwiseConv2DNchwChwOp, or a
+/// LinalgOp that passes isaConvolutionOpInterface and whose indexing maps
+/// match that op's layout. In the latter case tempDilations/tempStrides
+/// (initialized to 1 and passed to matchConvDimAddExprPattern) are forwarded
+/// to updateConvDilationsAndStrides, whose result also gates the return value.
+/// NOTE(review): the named-op early return leaves `dilations`/`strides`
+/// untouched; confirm callers do not rely on them being set in that case.
+bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                 SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,4})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[1], tempStrides[1]));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, tempDilations[1],
+                                  tempStrides[1]));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+/// Returns true if `op` is a linalg::DepthwiseConv2DNhwcHwcOp, or a
+/// LinalgOp that passes isaConvolutionOpInterface and whose indexing maps
+/// match that op's layout. In the latter case tempDilations/tempStrides
+/// (initialized to 1 and passed to matchConvDimAddExprPattern) are forwarded
+/// to updateConvDilationsAndStrides, whose result also gates the return value.
+/// NOTE(review): the named-op early return leaves `dilations`/`strides`
+/// untouched; confirm callers do not rely on them being set in that case.
+bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                 SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,4})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+/// Returns true if `op` is a linalg::DepthwiseConv2DNhwcHwcmOp, or a
+/// LinalgOp that passes isaConvolutionOpInterface and whose indexing maps
+/// match that op's layout. In the latter case tempDilations/tempStrides
+/// (initialized to 1 and passed to matchConvDimAddExprPattern) are forwarded
+/// to updateConvDilationsAndStrides, whose result also gates the return value.
+/// NOTE(review): the named-op early return leaves `dilations`/`strides`
+/// untouched; confirm callers do not rely on them being set in that case.
+bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                  SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,5})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 5}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
 
-bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::DepthwiseConv2DNhwcHwcmQOp>(op)) return true;
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // Indexing maps of the input, filter and output operands:
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) ->
+  //   (d0, d1 + d5, d2 + d6, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+/// Returns true if `op` is a linalg::DepthwiseConv2DNhwcHwcmQOp, or a
+/// LinalgOp that passes isaConvolutionOpInterface and whose indexing maps
+/// match that op's layout. In the latter case tempDilations/tempStrides
+/// (initialized to 1 and passed to matchConvDimAddExprPattern) are forwarded
+/// to updateConvDilationsAndStrides, whose result also gates the return value.
+/// NOTE(review): the named-op early return leaves `dilations`/`strides`
+/// untouched; confirm callers do not rely on them being set in that case.
+bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                   SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcmQOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,5})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 0, 0, 5}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 4;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // Indexing maps of operands 0 (input), 1 (filter), 2-3, and 4 (output):
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) ->
+  //   (d0, d1 + d5, d2 + d6, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+/// Returns true if `op` is a linalg::DepthwiseConv2DNhwcHwcQOp, or a
+/// LinalgOp that passes isaConvolutionOpInterface and whose indexing maps
+/// match that op's layout. In the latter case tempDilations/tempStrides
+/// (initialized to 1 and passed to matchConvDimAddExprPattern) are forwarded
+/// to updateConvDilationsAndStrides, whose result also gates the return value.
+/// NOTE(review): the named-op early return leaves `dilations`/`strides`
+/// untouched; confirm callers do not rely on them being set in that case.
+bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                  SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,0,0,4})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 0, 0, 4}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 4;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
   // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
 }
 
+/// Returns true if `op` is a linalg::Conv3DOp, or a LinalgOp that passes
+/// isaConvolutionOpInterface and whose indexing maps match that op's layout.
+/// Unlike its siblings that take `dilations`/`strides`, the dilations and
+/// strides gathered while matching are not reported to the caller.
 bool isaConv3DOp(LinalgOp op) {
-  if (isa<linalg::Conv3DOp>(op)) return true;
+  if (isa<linalg::Conv3DOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false;
-  
-  SmallVector<int64_t> tempDilations(3,1);
-  SmallVector<int64_t> tempStrides(3,1);
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3, 3, 3}))
+    return false;
+
+  SmallVector<int64_t> tempDilations(3, 1);
+  SmallVector<int64_t> tempStrides(3, 1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
-  return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[1], tempStrides[1]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[2], tempStrides[2]));
-}
-
-bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::Conv3DNcdhwFcdhwOp>(op)) return true;
+  return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
+                                     /*oDim=*/0, tempDilations[0],
+                                     tempStrides[0]) &&
+          matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+                                     /*oDim=*/1, tempDilations[1],
+                                     tempStrides[1]) &&
+          matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                     /*oDim=*/2, tempDilations[2],
+                                     tempStrides[2]));
+}
+
+/// Returns true if `op` is a linalg::Conv3DNcdhwFcdhwOp, or a
+/// LinalgOp that passes isaConvolutionOpInterface and whose indexing maps
+/// match that op's layout. In the latter case tempDilations/tempStrides
+/// (initialized to 1 and passed to matchConvDimAddExprPattern) are forwarded
+/// to updateConvDilationsAndStrides, whose result also gates the return value.
+/// NOTE(review): the named-op early return leaves `dilations`/`strides`
+/// untouched; confirm callers do not rely on them being set in that case.
+bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                           SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv3DNcdhwFcdhwOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(3,1);
-  SmallVector<int64_t> tempStrides(3,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[2], tempStrides[2]) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::Conv3DNdhwcDhwcfOp>(op)) return true;
-
-  if (!isaConvolutionOpInterface(op)) return false;
+
+  SmallVector<int64_t> tempDilations(3, 1);
+  SmallVector<int64_t> tempStrides(3, 1);
+  // Indexing maps of the input, filter and output operands:
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) ->
+  //   (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) ->
+  //   (d1, d5, d6, d7, d8)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) ->
+  //   (d0, d1, d2, d3, d4)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                  /*oDim=*/3, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
+                                  /*oDim=*/4, tempDilations[2],
+                                  tempStrides[2]) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+/// Returns true if `op` is a linalg::Conv3DNdhwcDhwcfOp, or a
+/// LinalgOp that passes isaConvolutionOpInterface and whose indexing maps
+/// match that op's layout. In the latter case tempDilations/tempStrides
+/// (initialized to 1 and passed to matchConvDimAddExprPattern) are forwarded
+/// to updateConvDilationsAndStrides, whose result also gates the return value.
+/// NOTE(review): the named-op early return leaves `dilations`/`strides`
+/// untouched; confirm callers do not rely on them being set in that case.
+bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                           SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv3DNdhwcDhwcfOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(3,1);
-  SmallVector<int64_t> tempStrides(3,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::Conv3DNdhwcDhwcfQOp>(op)) return true;
-
-  if (!isaConvolutionOpInterface(op)) return false;
+
+  SmallVector<int64_t> tempDilations(3, 1);
+  SmallVector<int64_t> tempStrides(3, 1);
+  // Indexing maps of the input, filter and output operands:
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) ->
+  //   (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) ->
+  //   (d5, d6, d7, d8, d4)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) ->
+  //   (d0, d1, d2, d3, d4)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, tempDilations[2],
+                                  tempStrides[2]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+/// Returns true if `op` is a linalg::Conv3DNdhwcDhwcfQOp, or a
+/// LinalgOp that passes isaConvolutionOpInterface and whose indexing maps
+/// match that op's layout. In the latter case tempDilations/tempStrides
+/// (initialized to 1 and passed to matchConvDimAddExprPattern) are forwarded
+/// to updateConvDilationsAndStrides, whose result also gates the return value.
+/// NOTE(review): the named-op early return leaves `dilations`/`strides`
+/// untouched; confirm callers do not rely on them being set in that case.
+bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                            SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv3DNdhwcDhwcfQOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 0, 0, 5}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 4;
-  
-  SmallVector<int64_t> tempDilations(3,1);
-  SmallVector<int64_t> tempStrides(3,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> ()>
-  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op)) return true;
-
-  if (!isaConvolutionOpInterface(op)) return false;
+
+  SmallVector<int64_t> tempDilations(3, 1);
+  SmallVector<int64_t> tempStrides(3, 1);
+  // Indexing maps of operands 0 (input), 1 (filter), 2-3, and 4 (output):
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) ->
+  //   (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) ->
+  //   (d5, d6, d7, d8, d4)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> ()>
+  // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) ->
+  //   (d0, d1, d2, d3, d4)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, tempDilations[2],
+                                  tempStrides[2]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
+                                    SmallVector<int64_t> *dilations,
+                                    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,6})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(3,1);
-  SmallVector<int64_t> tempStrides(3,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
-      matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::DepthwiseConv3DNcdhwCdhwOp>(op)) return true;
-
-  if (!isaConvolutionOpInterface(op)) return false;
+
+  SmallVector<int64_t> tempDilations(3, 1);
+  SmallVector<int64_t> tempStrides(3, 1);
+  // Expected indexing maps over (d0, ..., d8):
+  //   #map  (input)  = (d0, d1 + d5, d2 + d6, d3 + d7, d8)
+  //   #map1 (filter) = (d5, d6, d7, d8, d4)
+  //   #map2 (output) = (d0, d1, d2, d3, d8, d4)
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, tempDilations[2],
+                                  tempStrides[2]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                   SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv3DNcdhwCdhwOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 4, 5}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(3,1);
-  SmallVector<int64_t> tempStrides(3,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4, tempDilations[2], tempStrides[2]));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::DepthwiseConv3DNdhwcDhwcOp>(op)) return true;
-
-  if (!isaConvolutionOpInterface(op)) return false;
+
+  SmallVector<int64_t> tempDilations(3, 1);
+  SmallVector<int64_t> tempStrides(3, 1);
+  // Expected indexing maps over (d0, ..., d7):
+  //   #map  (input)  = (d0, d7, d1 + d4, d2 + d5, d3 + d6)
+  //   #map1 (filter) = (d7, d4, d5, d6)
+  //   #map2 (output) = (d0, d7, d1, d2, d3)
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3,
+                                  /*oDim=*/4, tempDilations[2],
+                                  tempStrides[2]));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                   SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv3DNdhwcDhwcOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 4, 5}))
+    return false;
+
   unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(3,1);
-  SmallVector<int64_t> tempStrides(3,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::PoolingNchwMaxOp>(op)) return true;
-
-  if (!isaConvolutionOpInterface(op)) return false;
+
+  SmallVector<int64_t> tempDilations(3, 1);
+  SmallVector<int64_t> tempStrides(3, 1);
+  // Expected indexing maps over (d0, ..., d7):
+  //   #map  (input)  = (d0, d1 + d4, d2 + d5, d3 + d6, d7)
+  //   #map1 (filter) = (d4, d5, d6, d7)
+  //   #map2 (output) = (d0, d1, d2, d3, d7)
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, tempDilations[2],
+                                  tempStrides[2]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                         SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNchwMaxOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+    return false;
+
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
-      bodyMatcherForMaxSignedPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::PoolingNchwSumOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0,
+                                  /*oDim=*/2, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1,
+                                  /*oDim=*/3, tempDilations[1],
+                                  tempStrides[1]) &&
+       bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                         SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNchwSumOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+    return false;
+
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
-      bodyMatcherForSumPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::PoolingNhwcMaxOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0,
+                                  /*oDim=*/2, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1,
+                                  /*oDim=*/3, tempDilations[1],
+                                  tempStrides[1]) &&
+       bodyMatcherForSumPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                         SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNhwcMaxOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+    return false;
+
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
-      bodyMatcherForMaxSignedPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::PoolingNhwcMinOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                         SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNhwcMinOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+    return false;
+
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
-      bodyMatcherForMinSignedPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::PoolingNhwcSumOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       bodyMatcherForMinSignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                         SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNhwcSumOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+    return false;
+
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
-      bodyMatcherForSumPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       bodyMatcherForSumPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                 SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+    return false;
+
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
-      bodyMatcherForMaxUnsignedPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::PoolingNhwcMinUnsignedOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       bodyMatcherForMaxUnsignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                 SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNhwcMinUnsignedOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+    return false;
+
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(2,1);
-  SmallVector<int64_t> tempStrides(2,1);
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
   // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
   // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
   // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
-      bodyMatcherForMinUnsignedPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::PoolingNcwMaxOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       bodyMatcherForMinUnsignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                        SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNcwMaxOp>(op))
+    return true;
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3}))
+    return false;
+
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
- 
-  SmallVector<int64_t> tempDilations(1,1);
-  SmallVector<int64_t> tempStrides(1,1);
+
+  SmallVector<int64_t> tempDilations(1, 1);
+  SmallVector<int64_t> tempStrides(1, 1);
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
-      bodyMatcherForMaxSignedPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::PoolingNcwSumOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0,
+                                  /*oDim=*/2, tempDilations[0],
+                                  tempStrides[0]) &&
+       bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                        SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNcwSumOp>(op))
+    return true; // NOTE: *dilations/*strides are not populated here.
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3}))
+    return false;
+
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
- 
-  SmallVector<int64_t> tempDilations(1,1);
-  SmallVector<int64_t> tempStrides(1,1);
+
+  SmallVector<int64_t> tempDilations(1, 1); // For input dim 2 (d2 + d3).
+  SmallVector<int64_t> tempStrides(1, 1);
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
-      bodyMatcherForSumPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::PoolingNwcMaxOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0,
+                                  /*oDim=*/2, tempDilations[0],
+                                  tempStrides[0]) &&
+       bodyMatcherForSumPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                        SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNwcMaxOp>(op))
+    return true; // NOTE: *dilations/*strides are not populated here.
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3}))
+    return false;
+
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
 
-  SmallVector<int64_t> tempDilations(1,1);
-  SmallVector<int64_t> tempStrides(1,1);
+  SmallVector<int64_t> tempDilations(1, 1); // For input dim 1 (d1 + d3).
+  SmallVector<int64_t> tempStrides(1, 1);
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
-      bodyMatcherForMaxSignedPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::PoolingNwcMinOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+       bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                        SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNwcMinOp>(op))
+    return true; // NOTE: *dilations/*strides are not populated here.
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3}))
+    return false;
+
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
 
-  SmallVector<int64_t> tempDilations(1,1);
-  SmallVector<int64_t> tempStrides(1,1);
+  SmallVector<int64_t> tempDilations(1, 1); // For input dim 1 (d1 + d3).
+  SmallVector<int64_t> tempStrides(1, 1);
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
-      bodyMatcherForMinSignedPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::PoolingNwcSumOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+       bodyMatcherForMinSignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                        SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNwcSumOp>(op))
+    return true; // NOTE: *dilations/*strides are not populated here.
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3}))
+    return false;
+
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
 
-  SmallVector<int64_t> tempDilations(1,1);
-  SmallVector<int64_t> tempStrides(1,1);
+  SmallVector<int64_t> tempDilations(1, 1); // For input dim 1 (d1 + d3).
+  SmallVector<int64_t> tempStrides(1, 1);
   // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
   // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
   // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
-      bodyMatcherForSumPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::PoolingNdhwcMaxOp>(op)) return true;
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+       bodyMatcherForSumPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                          SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNdhwcMaxOp>(op))
+    return true; // NOTE: *dilations/*strides are not populated here.
 
-  if (!isaConvolutionOpInterface(op)) return false;
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {5,3,5})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 3, 5}))
+    return false;
+
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(3,1);
-  SmallVector<int64_t> tempStrides(3,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
-      bodyMatcherForMaxSignedPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::PoolingNdhwcMinOp>(op)) return true;
-
-  if (!isaConvolutionOpInterface(op)) return false;
+
+  SmallVector<int64_t> tempDilations(3, 1); // For input dims 1-3.
+  SmallVector<int64_t> tempStrides(3, 1);
+  // Expected indexing maps, all over (d0, d1, d2, d3, d4, d5, d6, d7):
+  //   #map  -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)
+  //   #map1 -> (d5, d6, d7)
+  //   #map2 -> (d0, d1, d2, d3, d4)
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, tempDilations[2],
+                                  tempStrides[2]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+       bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                          SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNdhwcMinOp>(op))
+    return true; // NOTE: *dilations/*strides are not populated here.
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {5,3,5})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 3, 5}))
+    return false;
+
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(3,1);
-  SmallVector<int64_t> tempStrides(3,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
-      bodyMatcherForMinSignedPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
-  if (isa<linalg::PoolingNdhwcSumOp>(op)) return true;
-
-  if (!isaConvolutionOpInterface(op)) return false;
+
+  SmallVector<int64_t> tempDilations(3, 1); // For input dims 1-3.
+  SmallVector<int64_t> tempStrides(3, 1);
+  // Expected indexing maps, all over (d0, d1, d2, d3, d4, d5, d6, d7):
+  //   #map  -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)
+  //   #map1 -> (d5, d6, d7)
+  //   #map2 -> (d0, d1, d2, d3, d4)
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, tempDilations[2],
+                                  tempStrides[2]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+       bodyMatcherForMinSignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                          SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNdhwcSumOp>(op))
+    return true; // NOTE: *dilations/*strides are not populated here.
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
-  if (!verifyConvIndexingMapSizes(indexingMaps, {5,3,5})) return false;
-  
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 3, 5}))
+    return false;
+
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
   unsigned iIndex = 0, oIndex = 2;
-  
-  SmallVector<int64_t> tempDilations(3,1);
-  SmallVector<int64_t> tempStrides(3,1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
-  bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
-      matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
-      matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
-      bodyMatcherForSumPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
+
+  SmallVector<int64_t> tempDilations(3, 1); // For input dims 1-3.
+  SmallVector<int64_t> tempStrides(3, 1);
+  // Expected indexing maps, all over (d0, d1, d2, d3, d4, d5, d6, d7):
+  //   #map  -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)
+  //   #map1 -> (d5, d6, d7)
+  //   #map2 -> (d0, d1, d2, d3, d4)
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, tempDilations[2],
+                                  tempStrides[2]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+       bodyMatcherForSumPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
 }
 
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,

>From 3852dc4ffeac76056c4e31ac74397ade5f3dc228 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 15 Oct 2025 03:26:58 -0500
Subject: [PATCH 18/18] Export just a single API

---
 .../include/mlir/Dialect/Linalg/Utils/Utils.h | 116 +----
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 132 +++--
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 452 ++++++++++++++----
 3 files changed, 450 insertions(+), 250 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 0f39098ca9946..771d753a8bddb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -111,123 +111,13 @@ std::optional<SmallVector<ReassociationIndices>>
 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
 
 //===----------------------------------------------------------------------===//
-// Convolution matcher utilities
+// Convolution matcher utility
 //===----------------------------------------------------------------------===//
 
-bool isaConv1DOp(LinalgOp op);
-bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
-                       SmallVector<int64_t> *strides = nullptr);
-bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
-                       SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv1DNcwCwOp(LinalgOp op,
-                               SmallVector<int64_t> *dilations = nullptr,
-                               SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
-                               SmallVector<int64_t> *dilations = nullptr,
-                               SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op,
-                                SmallVector<int64_t> *dilations = nullptr,
-                                SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DOp(LinalgOp op);
-bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
-                         SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
-                         SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
-                         SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNhwcFhwcQOp(LinalgOp op,
-                          SmallVector<int64_t> *dilations = nullptr,
-                          SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNchwFchwQOp(LinalgOp op,
-                          SmallVector<int64_t> *dilations = nullptr,
-                          SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNgchwFgchwOp(LinalgOp op,
-                           SmallVector<int64_t> *dilations = nullptr,
-                           SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNgchwGfchwOp(LinalgOp op,
-                           SmallVector<int64_t> *dilations = nullptr,
-                           SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNhwcHwcfQOp(LinalgOp op,
-                          SmallVector<int64_t> *dilations = nullptr,
-                          SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNhwgcGfhwcQOp(LinalgOp op,
+template <typename ConvOpTy> // Named op to match, e.g. linalg::Conv2DOp.
+bool isaConvolutionOpOfType(LinalgOp op,
                             SmallVector<int64_t> *dilations = nullptr,
                             SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNgchwGfchwQOp(LinalgOp op,
-                            SmallVector<int64_t> *dilations = nullptr,
-                            SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNhwgcGfhwcOp(LinalgOp op,
-                           SmallVector<int64_t> *dilations = nullptr,
-                           SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
-                                 SmallVector<int64_t> *dilations = nullptr,
-                                 SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op,
-                                 SmallVector<int64_t> *dilations = nullptr,
-                                 SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op,
-                                  SmallVector<int64_t> *dilations = nullptr,
-                                  SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op,
-                                  SmallVector<int64_t> *dilations = nullptr,
-                                  SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op,
-                                   SmallVector<int64_t> *dilations = nullptr,
-                                   SmallVector<int64_t> *strides = nullptr);
-bool isaConv3DOp(LinalgOp op);
-bool isaConv3DNcdhwFcdhwOp(LinalgOp op,
-                           SmallVector<int64_t> *dilations = nullptr,
-                           SmallVector<int64_t> *strides = nullptr);
-bool isaConv3DNdhwcDhwcfOp(LinalgOp op,
-                           SmallVector<int64_t> *dilations = nullptr,
-                           SmallVector<int64_t> *strides = nullptr);
-bool isaConv3DNdhwcDhwcfQOp(LinalgOp op,
-                            SmallVector<int64_t> *dilations = nullptr,
-                            SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
-                                    SmallVector<int64_t> *dilations = nullptr,
-                                    SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op,
-                                   SmallVector<int64_t> *dilations = nullptr,
-                                   SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op,
-                                   SmallVector<int64_t> *dilations = nullptr,
-                                   SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
-                         SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
-                         SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
-                         SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
-                         SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
-                         SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op,
-                                 SmallVector<int64_t> *dilations = nullptr,
-                                 SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNhwcMinUnsignedOp(LinalgOp op,
-                                 SmallVector<int64_t> *dilations = nullptr,
-                                 SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
-                        SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
-                        SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
-                        SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
-                        SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
-                        SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNdhwcMaxOp(LinalgOp op,
-                          SmallVector<int64_t> *dilations = nullptr,
-                          SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNdhwcMinOp(LinalgOp op,
-                          SmallVector<int64_t> *dilations = nullptr,
-                          SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNdhwcSumOp(LinalgOp op,
-                          SmallVector<int64_t> *dilations = nullptr,
-                          SmallVector<int64_t> *strides = nullptr);
 
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index e08705b90e7b0..929904fa2c510 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -268,7 +268,7 @@ static FailureOr<LinalgOp>
 inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter,
                                                 GenericOp genericOp) {
   SmallVector<int64_t> dilations, strides;
-  if (isaConv1DOp(genericOp))
+  if (isaConvolutionOpOfType<linalg::Conv1DOp>(genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::Conv1DOp>(rewriter, genericOp, dilations,
                                                 strides);
   return failure();
@@ -278,28 +278,35 @@ static FailureOr<LinalgOp>
 inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
                                                 GenericOp genericOp) {
   SmallVector<int64_t> dilations, strides;
-  if (isaDepthwiseConv1DNcwCwOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
+          genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::DepthwiseConv1DNcwCwOp>(
         rewriter, genericOp, dilations, strides);
-  if (isaDepthwiseConv1DNwcWcOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+          genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
         rewriter, genericOp, dilations, strides);
-  if (isaConv2DOp(genericOp))
+  if (isaConvolutionOpOfType<linalg::Conv2DOp>(genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::Conv2DOp>(rewriter, genericOp, dilations,
                                                 strides);
-  if (isaPoolingNcwMaxOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::PoolingNcwMaxOp>(genericOp, &dilations,
+                                                      &strides))
     return specializeToConvOp<linalg::PoolingNcwMaxOp>(rewriter, genericOp,
                                                        dilations, strides);
-  if (isaPoolingNcwSumOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::PoolingNcwSumOp>(genericOp, &dilations,
+                                                      &strides))
     return specializeToConvOp<linalg::PoolingNcwSumOp>(rewriter, genericOp,
                                                        dilations, strides);
-  if (isaPoolingNwcMaxOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::PoolingNwcMaxOp>(genericOp, &dilations,
+                                                      &strides))
     return specializeToConvOp<linalg::PoolingNwcMaxOp>(rewriter, genericOp,
                                                        dilations, strides);
-  if (isaPoolingNwcMinOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::PoolingNwcMinOp>(genericOp, &dilations,
+                                                      &strides))
     return specializeToConvOp<linalg::PoolingNwcMinOp>(rewriter, genericOp,
                                                        dilations, strides);
-  if (isaPoolingNwcSumOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::PoolingNwcSumOp>(genericOp, &dilations,
+                                                      &strides))
     return specializeToConvOp<linalg::PoolingNwcSumOp>(rewriter, genericOp,
                                                        dilations, strides);
   return failure();
@@ -309,13 +316,16 @@ static FailureOr<LinalgOp>
 inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter,
                                                 GenericOp genericOp) {
   SmallVector<int64_t> dilations, strides;
-  if (isaDepthwiseConv1DNwcWcmOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
+          genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::DepthwiseConv1DNwcWcmOp>(
         rewriter, genericOp, dilations, strides);
-  if (isaConv1DNwcWcfOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(genericOp, &dilations,
+                                                     &strides))
     return specializeToConvOp<linalg::Conv1DNwcWcfOp>(rewriter, genericOp,
                                                       dilations, strides);
-  if (isaConv1DNcwFcwOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(genericOp, &dilations,
+                                                     &strides))
     return specializeToConvOp<linalg::Conv1DNcwFcwOp>(rewriter, genericOp,
                                                       dilations, strides);
   return failure();
@@ -325,37 +335,47 @@ static FailureOr<LinalgOp>
 inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
                                                 GenericOp genericOp) {
   SmallVector<int64_t> dilations, strides;
-  if (isaDepthwiseConv2DNchwChwOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+          genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
         rewriter, genericOp, dilations, strides);
-  if (isaDepthwiseConv2DNhwcHwcOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(
+          genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcOp>(
         rewriter, genericOp, dilations, strides);
-  if (isaDepthwiseConv2DNhwcHwcQOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
+          genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcQOp>(
         rewriter, genericOp, dilations, strides);
-  if (isaConv3DOp(genericOp))
+  if (isaConvolutionOpOfType<linalg::Conv3DOp>(genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::Conv3DOp>(rewriter, genericOp, dilations,
                                                 strides);
-  if (isaPoolingNchwMaxOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::PoolingNchwMaxOp>(genericOp, &dilations,
+                                                       &strides))
     return specializeToConvOp<linalg::PoolingNchwMaxOp>(rewriter, genericOp,
                                                         dilations, strides);
-  if (isaPoolingNchwSumOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::PoolingNchwSumOp>(genericOp, &dilations,
+                                                       &strides))
     return specializeToConvOp<linalg::PoolingNchwSumOp>(rewriter, genericOp,
                                                         dilations, strides);
-  if (isaPoolingNhwcMaxOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
+                                                       &strides))
     return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
                                                         dilations, strides);
-  if (isaPoolingNhwcMinOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(genericOp, &dilations,
+                                                       &strides))
     return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp,
                                                         dilations, strides);
-  if (isaPoolingNhwcSumOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(genericOp, &dilations,
+                                                       &strides))
     return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp,
                                                         dilations, strides);
-  if (isaPoolingNhwcMaxUnsignedOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+          genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(
         rewriter, genericOp, dilations, strides);
-  if (isaPoolingNhwcMinUnsignedOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+          genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(
         rewriter, genericOp, dilations, strides);
   return failure();
@@ -365,28 +385,36 @@ static FailureOr<LinalgOp>
 inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter,
                                                 GenericOp genericOp) {
   SmallVector<int64_t> dilations, strides;
-  if (isaConv2DNhwcFhwcOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(genericOp, &dilations,
+                                                       &strides))
     return specializeToConvOp<linalg::Conv2DNhwcFhwcOp>(rewriter, genericOp,
                                                         dilations, strides);
-  if (isaConv2DNhwcHwcfOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(genericOp, &dilations,
+                                                       &strides))
     return specializeToConvOp<linalg::Conv2DNhwcHwcfOp>(rewriter, genericOp,
                                                         dilations, strides);
-  if (isaConv2DNchwFchwOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(genericOp, &dilations,
+                                                       &strides))
     return specializeToConvOp<linalg::Conv2DNchwFchwOp>(rewriter, genericOp,
                                                         dilations, strides);
-  if (isaConv2DNhwcFhwcQOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(genericOp, &dilations,
+                                                        &strides))
     return specializeToConvOp<linalg::Conv2DNhwcFhwcQOp>(rewriter, genericOp,
                                                          dilations, strides);
-  if (isaConv2DNchwFchwQOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(genericOp, &dilations,
+                                                        &strides))
     return specializeToConvOp<linalg::Conv2DNchwFchwQOp>(rewriter, genericOp,
                                                          dilations, strides);
-  if (isaConv2DNhwcHwcfQOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(genericOp, &dilations,
+                                                        &strides))
     return specializeToConvOp<linalg::Conv2DNhwcHwcfQOp>(rewriter, genericOp,
                                                          dilations, strides);
-  if (isaDepthwiseConv2DNhwcHwcmOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
+          genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcmOp>(
         rewriter, genericOp, dilations, strides);
-  if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmQOp>(
+          genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcmQOp>(
         rewriter, genericOp, dilations, strides);
   return failure();
@@ -396,34 +424,44 @@ static FailureOr<LinalgOp>
 inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter,
                                                 GenericOp genericOp) {
   SmallVector<int64_t> dilations, strides;
-  if (isaConv2DNgchwFgchwOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(genericOp, &dilations,
+                                                         &strides))
     return specializeToConvOp<linalg::Conv2DNgchwFgchwOp>(rewriter, genericOp,
                                                           dilations, strides);
-  if (isaConv2DNgchwGfchwOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(genericOp, &dilations,
+                                                         &strides))
     return specializeToConvOp<linalg::Conv2DNgchwGfchwOp>(rewriter, genericOp,
                                                           dilations, strides);
-  if (isaConv2DNgchwGfchwQOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(genericOp, &dilations,
+                                                          &strides))
     return specializeToConvOp<linalg::Conv2DNgchwGfchwQOp>(rewriter, genericOp,
                                                            dilations, strides);
-  if (isaConv2DNhwgcGfhwcOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(genericOp, &dilations,
+                                                         &strides))
     return specializeToConvOp<linalg::Conv2DNhwgcGfhwcOp>(rewriter, genericOp,
                                                           dilations, strides);
-  if (isaConv2DNhwgcGfhwcQOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(genericOp, &dilations,
+                                                          &strides))
     return specializeToConvOp<linalg::Conv2DNhwgcGfhwcQOp>(rewriter, genericOp,
                                                            dilations, strides);
-  if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNcdhwCdhwOp>(
+          genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::DepthwiseConv3DNcdhwCdhwOp>(
         rewriter, genericOp, dilations, strides);
-  if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcOp>(
+          genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(
         rewriter, genericOp, dilations, strides);
-  if (isaPoolingNdhwcMaxOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::PoolingNdhwcMaxOp>(genericOp, &dilations,
+                                                        &strides))
     return specializeToConvOp<linalg::PoolingNdhwcMaxOp>(rewriter, genericOp,
                                                          dilations, strides);
-  if (isaPoolingNdhwcMinOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::PoolingNdhwcMinOp>(genericOp, &dilations,
+                                                        &strides))
     return specializeToConvOp<linalg::PoolingNdhwcMinOp>(rewriter, genericOp,
                                                          dilations, strides);
-  if (isaPoolingNdhwcSumOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::PoolingNdhwcSumOp>(genericOp, &dilations,
+                                                        &strides))
     return specializeToConvOp<linalg::PoolingNdhwcSumOp>(rewriter, genericOp,
                                                          dilations, strides);
   return failure();
@@ -433,16 +471,20 @@ static FailureOr<LinalgOp>
 inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
                                                 GenericOp genericOp) {
   SmallVector<int64_t> dilations, strides;
-  if (isaConv3DNcdhwFcdhwOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::Conv3DNcdhwFcdhwOp>(genericOp, &dilations,
+                                                         &strides))
     return specializeToConvOp<linalg::Conv3DNcdhwFcdhwOp>(rewriter, genericOp,
                                                           dilations, strides);
-  if (isaConv3DNdhwcDhwcfOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfOp>(genericOp, &dilations,
+                                                         &strides))
     return specializeToConvOp<linalg::Conv3DNdhwcDhwcfOp>(rewriter, genericOp,
                                                           dilations, strides);
-  if (isaConv3DNdhwcDhwcfQOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfQOp>(genericOp, &dilations,
+                                                          &strides))
     return specializeToConvOp<linalg::Conv3DNdhwcDhwcfQOp>(rewriter, genericOp,
                                                            dilations, strides);
-  if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp, &dilations, &strides))
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+          genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
         rewriter, genericOp, dilations, strides);
   return failure();
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 8ea5e7a10e17e..13235d99887a7 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -240,15 +240,14 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
   return iteratorType == utils::IteratorType::reduction;
 }
 
-// -------------------------------
-// ---------- CONV ---------------
-// -------------------------------
+//===----------------------------------------------------------------------===//
+// Convolution matcher utilities
+//===----------------------------------------------------------------------===//
 
 /// Utility to match block body for linalg.pool* ops.
 template <typename... OpTypes>
 static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
   Operation *defOp = yieldVal.getDefiningOp();
-  // if (!defOp) return false;
   if (!(isa_and_present<OpTypes>(defOp) || ...))
     return false;
 
@@ -402,7 +401,7 @@ static bool updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
   return true;
 }
 
-bool isaConv1DOp(LinalgOp op) {
+static bool isaConv1DOp(LinalgOp op) {
   if (isa<linalg::Conv1DOp>(op))
     return true;
 
@@ -423,8 +422,8 @@ bool isaConv1DOp(LinalgOp op) {
                                     tempStrides[0]);
 }
 
-bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                       SmallVector<int64_t> *strides) {
+static bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                              SmallVector<int64_t> *strides) {
   if (isa<linalg::Conv1DNwcWcfOp>(op))
     return true;
 
@@ -453,8 +452,8 @@ bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                       SmallVector<int64_t> *strides) {
+static bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                              SmallVector<int64_t> *strides) {
   if (isa<linalg::Conv1DNcwFcwOp>(op))
     return true;
 
@@ -483,8 +482,9 @@ bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                               SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv1DNcwCwOp(LinalgOp op,
+                                      SmallVector<int64_t> *dilations,
+                                      SmallVector<int64_t> *strides) {
   if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
     return true;
 
@@ -514,8 +514,9 @@ bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector<int64_t> *dilations,
 }
 
 // -------------------
-bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                               SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
+                                      SmallVector<int64_t> *dilations,
+                                      SmallVector<int64_t> *strides) {
   if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
     return true;
 
@@ -544,8 +545,9 @@ bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                                SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op,
+                                       SmallVector<int64_t> *dilations,
+                                       SmallVector<int64_t> *strides) {
   if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
     return true;
 
@@ -575,7 +577,7 @@ bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaConv2DOp(LinalgOp op) {
+static bool isaConv2DOp(LinalgOp op) {
   if (isa<linalg::Conv2DOp>(op))
     return true;
 
@@ -599,8 +601,8 @@ bool isaConv2DOp(LinalgOp op) {
                                      tempStrides[1]));
 }
 
-bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                         SmallVector<int64_t> *strides) {
+static bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                SmallVector<int64_t> *strides) {
   if (isa<linalg::Conv2DNhwcFhwcOp>(op))
     return true;
 
@@ -632,8 +634,8 @@ bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                         SmallVector<int64_t> *strides) {
+static bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                SmallVector<int64_t> *strides) {
   if (isa<linalg::Conv2DNhwcHwcfOp>(op))
     return true;
 
@@ -665,8 +667,8 @@ bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                         SmallVector<int64_t> *strides) {
+static bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                SmallVector<int64_t> *strides) {
   if (isa<linalg::Conv2DNchwFchwOp>(op))
     return true;
 
@@ -698,8 +700,8 @@ bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                          SmallVector<int64_t> *strides) {
+static bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                 SmallVector<int64_t> *strides) {
   if (isa<linalg::Conv2DNhwcFhwcQOp>(op))
     return true;
 
@@ -732,8 +734,8 @@ bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                          SmallVector<int64_t> *strides) {
+static bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                 SmallVector<int64_t> *strides) {
   if (isa<linalg::Conv2DNchwFchwQOp>(op))
     return true;
 
@@ -766,8 +768,8 @@ bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                           SmallVector<int64_t> *strides) {
+static bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                  SmallVector<int64_t> *strides) {
   if (isa<linalg::Conv2DNgchwFgchwOp>(op))
     return true;
 
@@ -802,8 +804,8 @@ bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                           SmallVector<int64_t> *strides) {
+static bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                  SmallVector<int64_t> *strides) {
   if (isa<linalg::Conv2DNgchwGfchwOp>(op))
     return true;
 
@@ -838,8 +840,8 @@ bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                          SmallVector<int64_t> *strides) {
+static bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                 SmallVector<int64_t> *strides) {
   if (isa<linalg::Conv2DNhwcHwcfQOp>(op))
     return true;
 
@@ -872,8 +874,8 @@ bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                            SmallVector<int64_t> *strides) {
+static bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                   SmallVector<int64_t> *strides) {
   if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op))
     return true;
 
@@ -908,8 +910,8 @@ bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                            SmallVector<int64_t> *strides) {
+static bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                   SmallVector<int64_t> *strides) {
   if (isa<linalg::Conv2DNgchwGfchwQOp>(op))
     return true;
 
@@ -945,8 +947,8 @@ bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                           SmallVector<int64_t> *strides) {
+static bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                  SmallVector<int64_t> *strides) {
   if (isa<linalg::Conv2DNhwgcGfhwcOp>(op))
     return true;
 
@@ -981,8 +983,9 @@ bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                                 SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
+                                        SmallVector<int64_t> *dilations,
+                                        SmallVector<int64_t> *strides) {
   if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
     return true;
 
@@ -1014,8 +1017,9 @@ bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                                 SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op,
+                                        SmallVector<int64_t> *dilations,
+                                        SmallVector<int64_t> *strides) {
   if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(op))
     return true;
 
@@ -1047,8 +1051,9 @@ bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                                  SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op,
+                                         SmallVector<int64_t> *dilations,
+                                         SmallVector<int64_t> *strides) {
   if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op))
     return true;
 
@@ -1081,8 +1086,9 @@ bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                                   SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op,
+                                          SmallVector<int64_t> *dilations,
+                                          SmallVector<int64_t> *strides) {
   if (isa<linalg::DepthwiseConv2DNhwcHwcmQOp>(op))
     return true;
 
@@ -1116,8 +1122,9 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                                  SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op,
+                                         SmallVector<int64_t> *dilations,
+                                         SmallVector<int64_t> *strides) {
   if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op))
     return true;
 
@@ -1150,7 +1157,7 @@ bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaConv3DOp(LinalgOp op) {
+static bool isaConv3DOp(LinalgOp op) {
   if (isa<linalg::Conv3DOp>(op))
     return true;
 
@@ -1177,8 +1184,8 @@ bool isaConv3DOp(LinalgOp op) {
                                      tempStrides[2]));
 }
 
-bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                           SmallVector<int64_t> *strides) {
+static bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                  SmallVector<int64_t> *strides) {
   if (isa<linalg::Conv3DNcdhwFcdhwOp>(op))
     return true;
 
@@ -1214,8 +1221,8 @@ bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                           SmallVector<int64_t> *strides) {
+static bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                  SmallVector<int64_t> *strides) {
   if (isa<linalg::Conv3DNdhwcDhwcfOp>(op))
     return true;
 
@@ -1251,8 +1258,8 @@ bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                            SmallVector<int64_t> *strides) {
+static bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                   SmallVector<int64_t> *strides) {
   if (isa<linalg::Conv3DNdhwcDhwcfQOp>(op))
     return true;
 
@@ -1289,9 +1296,9 @@ bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
-                                    SmallVector<int64_t> *dilations,
-                                    SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
+                                           SmallVector<int64_t> *dilations,
+                                           SmallVector<int64_t> *strides) {
   if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
     return true;
 
@@ -1328,8 +1335,9 @@ bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
                                                     tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                                   SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op,
+                                          SmallVector<int64_t> *dilations,
+                                          SmallVector<int64_t> *strides) {
   if (isa<linalg::DepthwiseConv3DNcdhwCdhwOp>(op))
     return true;
 
@@ -1365,8 +1373,9 @@ bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                                   SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op,
+                                          SmallVector<int64_t> *dilations,
+                                          SmallVector<int64_t> *strides) {
   if (isa<linalg::DepthwiseConv3DNdhwcDhwcOp>(op))
     return true;
 
@@ -1402,8 +1411,8 @@ bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                         SmallVector<int64_t> *strides) {
+static bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNchwMaxOp>(op))
     return true;
 
@@ -1438,8 +1447,8 @@ bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                         SmallVector<int64_t> *strides) {
+static bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNchwSumOp>(op))
     return true;
 
@@ -1474,8 +1483,8 @@ bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                         SmallVector<int64_t> *strides) {
+static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNhwcMaxOp>(op))
     return true;
 
@@ -1510,8 +1519,8 @@ bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                         SmallVector<int64_t> *strides) {
+static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNhwcMinOp>(op))
     return true;
 
@@ -1546,8 +1555,8 @@ bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                         SmallVector<int64_t> *strides) {
+static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNhwcSumOp>(op))
     return true;
 
@@ -1582,8 +1591,9 @@ bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                                 SmallVector<int64_t> *strides) {
+static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op,
+                                        SmallVector<int64_t> *dilations,
+                                        SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
     return true;
 
@@ -1618,8 +1628,9 @@ bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                                 SmallVector<int64_t> *strides) {
+static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op,
+                                        SmallVector<int64_t> *dilations,
+                                        SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNhwcMinUnsignedOp>(op))
     return true;
 
@@ -1654,8 +1665,8 @@ bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                        SmallVector<int64_t> *strides) {
+static bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                               SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNcwMaxOp>(op))
     return true;
 
@@ -1687,8 +1698,8 @@ bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                        SmallVector<int64_t> *strides) {
+static bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                               SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNcwSumOp>(op))
     return true;
 
@@ -1720,8 +1731,8 @@ bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                        SmallVector<int64_t> *strides) {
+static bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                               SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNwcMaxOp>(op))
     return true;
 
@@ -1753,8 +1764,8 @@ bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                        SmallVector<int64_t> *strides) {
+static bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                               SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNwcMinOp>(op))
     return true;
 
@@ -1786,8 +1797,8 @@ bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                        SmallVector<int64_t> *strides) {
+static bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                               SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNwcSumOp>(op))
     return true;
 
@@ -1819,8 +1830,8 @@ bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                          SmallVector<int64_t> *strides) {
+static bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                 SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNdhwcMaxOp>(op))
     return true;
 
@@ -1859,8 +1870,8 @@ bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                          SmallVector<int64_t> *strides) {
+static bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                 SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNdhwcMinOp>(op))
     return true;
 
@@ -1899,8 +1910,8 @@ bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                          SmallVector<int64_t> *strides) {
+static bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                 SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNdhwcSumOp>(op))
     return true;
 
@@ -1939,6 +1950,182 @@ bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
+/// Returns true if `op` matches the convolution or pooling op `ConvOpTy`.
+///
+/// Dispatches at compile time to the matching `isa*` helper above (e.g.
+/// `isaConv2DNhwcHwcfOp`). `dilations` and `strides` are forwarded to every
+/// helper that takes them; the helpers for Conv1DOp, Conv2DOp and Conv3DOp
+/// take only `op`, so for those types both pointers are ignored.
+///
+/// The template is explicitly instantiated below for every supported op type.
+/// Instantiating it with an op type that has no matcher is a compile-time
+/// error rather than a matcher that silently returns false.
+template <typename ConvOpTy>
+bool isaConvolutionOpOfType(LinalgOp op, SmallVector<int64_t> *dilations,
+                            SmallVector<int64_t> *strides) {
+  if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp>) {
+    return isaConv1DOp(op);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DNwcWcfOp>) {
+    return isaConv1DNwcWcfOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DNcwFcwOp>) {
+    return isaConv1DNcwFcwOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv1DNcwCwOp>) {
+    return isaDepthwiseConv1DNcwCwOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv1DNwcWcOp>) {
+    return isaDepthwiseConv1DNwcWcOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv1DNwcWcmOp>) {
+    return isaDepthwiseConv1DNwcWcmOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DOp>) {
+    return isaConv2DOp(op);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNhwcFhwcOp>) {
+    return isaConv2DNhwcFhwcOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNhwcHwcfOp>) {
+    return isaConv2DNhwcHwcfOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNchwFchwOp>) {
+    return isaConv2DNchwFchwOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNhwcFhwcQOp>) {
+    return isaConv2DNhwcFhwcQOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNchwFchwQOp>) {
+    return isaConv2DNchwFchwQOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNgchwFgchwOp>) {
+    return isaConv2DNgchwFgchwOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNgchwGfchwOp>) {
+    return isaConv2DNgchwGfchwOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNhwcHwcfQOp>) {
+    return isaConv2DNhwcHwcfQOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNhwgcGfhwcQOp>) {
+    return isaConv2DNhwgcGfhwcQOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNgchwGfchwQOp>) {
+    return isaConv2DNgchwGfchwQOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNhwgcGfhwcOp>) {
+    return isaConv2DNhwgcGfhwcOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv2DNchwChwOp>) {
+    return isaDepthwiseConv2DNchwChwOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv2DNhwcHwcOp>) {
+    return isaDepthwiseConv2DNhwcHwcOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv2DNhwcHwcmOp>) {
+    return isaDepthwiseConv2DNhwcHwcmOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv2DNhwcHwcQOp>) {
+    return isaDepthwiseConv2DNhwcHwcQOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv2DNhwcHwcmQOp>) {
+    return isaDepthwiseConv2DNhwcHwcmQOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+    return isaConv3DOp(op);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv3DNcdhwFcdhwOp>) {
+    return isaConv3DNcdhwFcdhwOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv3DNdhwcDhwcfOp>) {
+    return isaConv3DNdhwcDhwcfOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv3DNdhwcDhwcfQOp>) {
+    return isaConv3DNdhwcDhwcfQOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv3DNdhwcDhwcmOp>) {
+    return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv3DNcdhwCdhwOp>) {
+    return isaDepthwiseConv3DNcdhwCdhwOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv3DNdhwcDhwcOp>) {
+    return isaDepthwiseConv3DNdhwcDhwcOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNchwMaxOp>) {
+    return isaPoolingNchwMaxOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNchwSumOp>) {
+    return isaPoolingNchwSumOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcMaxOp>) {
+    return isaPoolingNhwcMaxOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcMinOp>) {
+    return isaPoolingNhwcMinOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcSumOp>) {
+    return isaPoolingNhwcSumOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::PoolingNhwcMaxUnsignedOp>) {
+    return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::PoolingNhwcMinUnsignedOp>) {
+    return isaPoolingNhwcMinUnsignedOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNcwMaxOp>) {
+    return isaPoolingNcwMaxOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNcwSumOp>) {
+    return isaPoolingNcwSumOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNwcMaxOp>) {
+    return isaPoolingNwcMaxOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNwcMinOp>) {
+    return isaPoolingNwcMinOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNwcSumOp>) {
+    return isaPoolingNwcSumOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNdhwcMaxOp>) {
+    return isaPoolingNdhwcMaxOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNdhwcMinOp>) {
+    return isaPoolingNdhwcMinOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNdhwcSumOp>) {
+    return isaPoolingNdhwcSumOp(op, dilations, strides);
+  } else {
+    // `sizeof(ConvOpTy) == 0` is always false but depends on ConvOpTy, so the
+    // assertion only fires when this branch is actually instantiated.
+    static_assert(sizeof(ConvOpTy) == 0,
+                  "isaConvolutionOpOfType: unsupported convolution op type");
+  }
+}
+
+// Explicit instantiations for every op type handled above.
+#define INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(OpTy)                           \
+  template bool isaConvolutionOpOfType<OpTy>(                                  \
+      LinalgOp op, SmallVector<int64_t> *dilations,                            \
+      SmallVector<int64_t> *strides)
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv1DOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv1DNwcWcfOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv1DNcwFcwOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::DepthwiseConv1DNcwCwOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::DepthwiseConv1DNwcWcOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::DepthwiseConv1DNwcWcmOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv2DOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv2DNhwcFhwcOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv2DNhwcHwcfOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv2DNchwFchwOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv2DNhwcFhwcQOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv2DNchwFchwQOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv2DNgchwFgchwOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv2DNgchwGfchwOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv2DNhwcHwcfQOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv2DNhwgcGfhwcQOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv2DNgchwGfchwQOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv2DNhwgcGfhwcOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::DepthwiseConv2DNchwChwOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::DepthwiseConv2DNhwcHwcOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::DepthwiseConv2DNhwcHwcmOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::DepthwiseConv2DNhwcHwcQOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::DepthwiseConv2DNhwcHwcmQOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv3DOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv3DNcdhwFcdhwOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv3DNdhwcDhwcfOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::Conv3DNdhwcDhwcfQOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::DepthwiseConv3DNdhwcDhwcmOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::DepthwiseConv3DNcdhwCdhwOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::DepthwiseConv3DNdhwcDhwcOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::PoolingNchwMaxOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::PoolingNchwSumOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::PoolingNhwcMaxOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::PoolingNhwcMinOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::PoolingNhwcSumOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::PoolingNhwcMaxUnsignedOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::PoolingNhwcMinUnsignedOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::PoolingNcwMaxOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::PoolingNcwSumOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::PoolingNwcMaxOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::PoolingNwcMinOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::PoolingNwcSumOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::PoolingNdhwcMaxOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::PoolingNdhwcMinOp);
+INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE(linalg::PoolingNdhwcSumOp);
+#undef INSTANTIATE_ISA_CONVOLUTION_OP_OF_TYPE
+
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                             Value source, Value pad, bool nofold,
                             ValueRange typeDynDims) {



More information about the Mlir-commits mailing list