[Mlir-commits] [mlir] [Linalg] Add basic infra to add matchers for linalg.*conv*/*pool* ops (PR #163724)

Abhishek Varma llvmlistbot at llvm.org
Fri Oct 24 01:00:03 PDT 2025


https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/163724

>From c6aea9193db7ad415f2fcde00e8bdcc3d98cfea4 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 16 Oct 2025 02:54:14 -0500
Subject: [PATCH 1/5] [Linalg] Add basic infra to add matchers for
 linalg.*conv*/*pool* ops

-- This commit includes the basic infra/utilities to add matchers for
   linalg.*conv*/*pool* ops - such that, given a `linalg.generic` op, they
   identify which linalg.*conv*/*pool* op it corresponds to (see the usage
   sketch below).
-- It adds a few representative linalg.*conv*/*pool* ops to demonstrate the
   matchers' capability, wired up as part of the
   `linalg-specialize-generic-ops` pass.
-- This works iteratively towards the goal of
   [[RFC] Op explosion in Linalg](https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863)
   for `*conv*/*pooling*` ops.
-- This is part 1 of a series of PRs aimed at adding matchers for Convolution ops.
-- For further details, refer to https://github.com/llvm/llvm-project/pull/163374#pullrequestreview-3341048722
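-- A minimal usage sketch of the matcher utility (hypothetical caller code,
   not part of this patch; `linalgOp` stands in for any `LinalgOp`, e.g. a
   `linalg.generic`):

     SmallVector<int64_t> dilations, strides;
     if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
             linalgOp, &dilations, &strides)) {
       // `linalgOp` is a linalg.depthwise_conv_1d_nwc_wc (or an equivalent
       // linalg.generic); `dilations` and `strides` now hold the inferred
       // attribute values.
     }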

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
 .../include/mlir/Dialect/Linalg/Utils/Utils.h |   9 +
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 144 +++++
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 502 ++++++++++++++++++
 .../convolution/roundtrip-convolution.mlir    | 112 ++++
 4 files changed, 767 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 48978eb7663d5..771d753a8bddb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -110,6 +110,15 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
 std::optional<SmallVector<ReassociationIndices>>
 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
 
+//===----------------------------------------------------------------------===//
+// Convolution matcher utility
+//===----------------------------------------------------------------------===//
+
+template <typename ConvOpTy>
+bool isaConvolutionOpOfType(LinalgOp op,
+                            SmallVector<int64_t> *dilations = nullptr,
+                            SmallVector<int64_t> *strides = nullptr);
+
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 40fc0d68e358f..35861002e309e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,6 +237,145 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
 }
 
+/// Utility to replace a `genericOp` with a named convolution op of type
+/// `ConvOpTy`, carrying the given `dilations` and `strides`.
+template <typename ConvOpTy>
+static FailureOr<LinalgOp>
+specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
+                   ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
+  SmallVector<Value> inputs = genericOp.getDpsInputs();
+  ValueRange outputs = genericOp.getDpsInits();
+  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+  SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
+                                      ? TypeRange(ValueRange(outputs))
+                                      : TypeRange{};
+  LinalgOp namedOp;
+  if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
+                                                    inputs, outputs);
+  } else {
+    Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+    Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+        genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+  }
+  return namedOp;
+}
+
+/// TODO(avarma): Convolution ops whose iterator-types array has rank 2 will
+/// be added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
+        rewriter, genericOp, dilations, strides);
+  return failure();
+}
+
+/// TODO(avarma): Convolution ops whose iterator-types array has rank 5 will
+/// be added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
+        rewriter, genericOp, dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
+                                                       &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
+                                                        dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(genericOp, &dilations,
+                                                       &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp,
+                                                        dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(genericOp, &dilations,
+                                                       &strides))
+    return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp,
+                                                        dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(
+        rewriter, genericOp, dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(
+        rewriter, genericOp, dilations, strides);
+  return failure();
+}
+
+/// TODO(avarma): Convolution ops whose iterator-types array has rank 7 will
+/// be added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+/// TODO(avarma): Convolution ops whose iterator-types array has rank 8 will
+/// be added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+        rewriter, genericOp, dilations, strides);
+  return failure();
+}
+
+// Converts linalg.generic to a named linalg.*conv/pooling* op where possible.
+// To speed up the search, the convolution ops are segregated based on the
+// rank of their iterator-types array.
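+// E.g. linalg.depthwise_conv_1d_nwc_wc has 4 iterators (N, W, C, w), while
+// linalg.depthwise_conv_2d_nchw_chw and the linalg.pooling_nhwc_* ops have 6.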
+static FailureOr<LinalgOp>
+inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
+  SmallVector<utils::IteratorType> iteratorTypes =
+      genericOp.getIteratorTypesArray();
+  unsigned totalIterators = iteratorTypes.size();
+  switch (totalIterators) {
+  case 2:
+    return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
+  case 4:
+    return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
+  case 5:
+    return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
+  case 6:
+    return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
+  case 7:
+    return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
+  case 8:
+    return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
+  case 9:
+    return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
+  }
+  return failure();
+}
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -316,6 +455,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
   if (isaContractionOpInterface(genericOp)) {
     return specializeLinalgContractions(rewriter, genericOp);
   }
+
+  // Convolution - e.g. *conv/pooling*
+  if (isaConvolutionOpInterface(genericOp)) {
+    return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
+  }
   return failure();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 24d3722cf5426..c3c2819652129 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -240,6 +240,508 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
   return iteratorType == utils::IteratorType::reduction;
 }
 
+//===----------------------------------------------------------------------===//
+// Convolution matcher utilities
+//===----------------------------------------------------------------------===//
+
+/// Utility to match the block body of linalg.pool* ops: the yielded value must
+/// be produced by one of `OpTypes`, with both operands being block arguments.
+template <typename... OpTypes>
+static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
+  Operation *defOp = yieldVal.getDefiningOp();
+  if (!(isa_and_present<OpTypes>(defOp) || ...))
+    return false;
+
+  BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
+  BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand(1));
+  if (!lhsArg || !rhsArg)
+    return false;
+  return true;
+}
+
+static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal,
+                                                                  body);
+}
+
+static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
+                                                                  body);
+}
+
+static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal,
+                                                                  body);
+}
+
+static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
+                                                                  body);
+}
+
+static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
+}
+
+static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
+                                        uint32_t mapIndex, uint32_t dimIndex) {
+  auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
+  if (dimIndex < affineMap.getNumResults())
+    return affineMap.getResult(dimIndex);
+  return nullptr;
+}
+
+// Check if `expr` is either:
+// - a dimension expr alone (implying a multiplier of 1), or
+// - a multiplication of a dimension expr by a constant.
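+// E.g. `d3` binds dim = d3 and constantValue = 1, while `d3 * 2` (or
+// `2 * d3`) binds dim = d3 and constantValue = 2.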
+static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim,
+                                        int64_t &constantValue) {
+  if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
+    dim = dExpr;
+    constantValue = 1;
+    return true;
+  }
+
+  auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+  if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
+    return false;
+
+  AffineExpr lhs = mulExpr.getLHS();
+  AffineExpr rhs = mulExpr.getRHS();
+
+  if (auto dExpr = dyn_cast<AffineDimExpr>(lhs)) {
+    if (auto cst = dyn_cast<AffineConstantExpr>(rhs)) {
+      dim = dExpr;
+      constantValue = cst.getValue();
+      return true;
+    }
+  }
+  if (auto cst = dyn_cast<AffineConstantExpr>(lhs)) {
+    if (auto dExpr = dyn_cast<AffineDimExpr>(rhs)) {
+      dim = dExpr;
+      constantValue = cst.getValue();
+      return true;
+    }
+  }
+  return false;
+}
+
+/// Given an array of AffineMaps `indexingMaps`, verify the following:
+///   indexingMaps[0].getResult(iDim) ==
+///         indexingMaps[1].getResult(fDim) * <CST_1> +
+///         indexingMaps[n-1].getResult(oDim) * <CST_2>
+/// where CST_1 (the dilation) and CST_2 (the stride) can be any constants.
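+/// E.g. with filter expr `d3` and output expr `d1`, the input expr `d1 + d3`
+/// matches with dilation = stride = 1, while `d1 * 2 + d3 * 3` matches with
+/// dilation = 3 and stride = 2.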
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
+                                       unsigned fDim, unsigned oDim,
+                                       int64_t &dilation, int64_t &stride) {
+  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+  AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
+  auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
+  if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
+    return false;
+
+  AffineExpr dim0, dim1;
+  int64_t c0, c1;
+
+  if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
+      isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) {
+    // Pattern matched with dims and constants extracted.
+    AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
+    AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
+    if (dim0 == fExpr && dim1 == oExpr) {
+      dilation = c0;
+      stride = c1;
+      return true;
+    } else if (dim1 == fExpr && dim0 == oExpr) {
+      dilation = c1;
+      stride = c0;
+      return true;
+    }
+  }
+  return false;
+}
+
+/// Given an array of AffineMaps `indexingMaps`, verify that:
+///   indexingMaps[aIndex].getResult(aDim) ==
+///   indexingMaps[bIndex].getResult(bDim)
+static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex,
+                                    unsigned aDim, unsigned bIndex,
+                                    unsigned bDim) {
+  return getAffineMapDim(indexingMaps, aIndex, aDim) ==
+         getAffineMapDim(indexingMaps, bIndex, bDim);
+}
+
+/// Given an array of AffineMaps, verify that each map has the corresponding
+/// number of results in `expectedSizes`.
+static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
+                                       ArrayRef<int64_t> expectedSizes) {
+  if (indexingMaps.size() != expectedSizes.size())
+    return false;
+
+  for (auto [indexingMap, expectedSize] :
+       llvm::zip_equal(indexingMaps, expectedSizes)) {
+    auto affineMap = cast<AffineMapAttr>(indexingMap).getValue();
+    if (affineMap.getNumResults() != expectedSize)
+      return false;
+  }
+  return true;
+}
+
+/// Utility to update `dilations` and `strides` by copying the corresponding
+/// data from `tempDilations` and `tempStrides`, if both pointers are non-null.
+static bool updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
+                                          SmallVector<int64_t> *strides,
+                                          ArrayRef<int64_t> tempDilations,
+                                          ArrayRef<int64_t> tempStrides) {
+  if (!(dilations && strides))
+    return true;
+  for (auto [dilation, stride] : llvm::zip(tempDilations, tempStrides)) {
+    dilations->push_back(dilation);
+    strides->push_back(stride);
+  }
+  return true;
+}
+
+static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
+                                      SmallVector<int64_t> *dilations,
+                                      SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3}))
+    return false;
+
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(1, 1);
+  SmallVector<int64_t> tempStrides(1, 1);
+  // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+  // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
+                                        SmallVector<int64_t> *dilations,
+                                        SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4}))
+    return false;
+
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, tempDilations[1],
+                                  tempStrides[1]));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
+                                           SmallVector<int64_t> *dilations,
+                                           SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6}))
+    return false;
+
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(3, 1);
+  SmallVector<int64_t> tempStrides(3, 1);
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
+  //                  -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
+  //                  -> (d5, d6, d7, d8, d4)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
+  //                  -> (d0, d1, d2, d3, d8, d4)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, tempDilations[2],
+                                  tempStrides[2]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNhwcMaxOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+    return false;
+
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNhwcMinOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+    return false;
+
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       bodyMatcherForMinSignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNhwcSumOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+    return false;
+
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       bodyMatcherForSumPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op,
+                                        SmallVector<int64_t> *dilations,
+                                        SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+    return false;
+
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       bodyMatcherForMaxUnsignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op,
+                                        SmallVector<int64_t> *dilations,
+                                        SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNhwcMinUnsignedOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+    return false;
+
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned iIndex = 0, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       bodyMatcherForMinUnsignedPoolOps(yieldVal, body));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+template <typename ConvOpTy>
+bool isaConvolutionOpOfType(LinalgOp op, SmallVector<int64_t> *dilations,
+                            SmallVector<int64_t> *strides) {
+  if constexpr (std::is_same_v<ConvOpTy, linalg::DepthwiseConv1DNwcWcOp>) {
+    return isaDepthwiseConv1DNwcWcOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv2DNchwChwOp>) {
+    return isaDepthwiseConv2DNchwChwOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv3DNdhwcDhwcmOp>) {
+    return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcMaxOp>) {
+    return isaPoolingNhwcMaxOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcMinOp>) {
+    return isaPoolingNhwcMinOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcSumOp>) {
+    return isaPoolingNhwcSumOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::PoolingNhwcMaxUnsignedOp>) {
+    return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::PoolingNhwcMinUnsignedOp>) {
+    return isaPoolingNhwcMinUnsignedOp(op, dilations, strides);
+  } else {
+    return false;
+  }
+}
+
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides);
+
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                             Value source, Value pad, bool nofold,
                             ValueRange typeDynDims) {
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
new file mode 100644
index 0000000000000..5a18ca8519be3
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -0,0 +1,112 @@
+// The following tests round-trip examples of linalg convolution named ops: each op is
+// lowered to linalg.generic and then lifted back up to the named op.
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+func.func @depthwise_conv_1d_nwc_wc(%input: memref<?x?x?xf32>, %filter: memref<?x?xf32>, %output: memref<?x?x?xf32>) {
+  linalg.depthwise_conv_1d_nwc_wc {dilations = dense<3> : tensor<1xi64>,
+                                       strides = dense<2> : tensor<1xi64>}
+     ins (%input, %filter: memref<?x?x?xf32>, memref<?x?xf32>)
+    outs (%output: memref<?x?x?xf32>)
+  return
+}
+//      CHECK: @depthwise_conv_1d_nwc_wc
+//      CHECK:   linalg.depthwise_conv_1d_nwc_wc
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_2d_nchw_chw(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_2d_nchw_chw
+//      CHECK:   linalg.depthwise_conv_2d_nchw_chw
+// CHECK-SAME:      dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[4, 5]> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm {dilations = dense<1> : tensor<3xi64>,
+                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_3d_ndhwc_dhwcm
+//      CHECK:   linalg.depthwise_conv_3d_ndhwc_dhwcm
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @pooling_nhwc_max
+//      CHECK:   linalg.pooling_nhwc_max
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_min(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @pooling_nhwc_min
+//      CHECK:   linalg.pooling_nhwc_min
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_sum(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @pooling_nhwc_sum
+//      CHECK:   linalg.pooling_nhwc_sum
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_max_unsigned(%input: tensor<?x?x?x?xi32>, %filter: tensor<?x?xi32>, %init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+  %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xi32>, tensor<?x?xi32>)
+    outs (%init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?xi32>
+}
+//      CHECK: @pooling_nhwc_max_unsigned
+//      CHECK:   linalg.pooling_nhwc_max_unsigned
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_min_unsigned(%input: tensor<?x?x?x?xi32>, %filter: tensor<?x?xi32>, %init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+  %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xi32>, tensor<?x?xi32>)
+    outs (%init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?xi32>
+}
+//      CHECK: @pooling_nhwc_min_unsigned
+//      CHECK:   linalg.pooling_nhwc_min_unsigned
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic

>From cd1b88a9d7febdd1f933ac22254303f74643f1c2 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Fri, 17 Oct 2025 02:46:56 -0500
Subject: [PATCH 2/5] Review comment v1.0

---
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 143 +++++++++---------
 .../convolution/roundtrip-convolution.mlir    |  16 +-
 2 files changed, 87 insertions(+), 72 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index c3c2819652129..4dfec7b361eab 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -418,9 +418,9 @@ static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
 
   SmallVector<int64_t> tempDilations(1, 1);
   SmallVector<int64_t> tempStrides(1, 1);
-  // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
-  // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
-  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  // #map = affine_map<(N, W, C, w) -> (N, W + w, C)>
+  // #map1 = affine_map<(N, W, C, w) -> (w, C)>
+  // #map2 = affine_map<(N, W, C, w) -> (N, W, C)>
   bool returnVal =
       (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
        matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
@@ -449,9 +449,9 @@ static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
 
   SmallVector<int64_t> tempDilations(2, 1);
   SmallVector<int64_t> tempStrides(2, 1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
+  // #map =  affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
+  // #map1 = affine_map<(N, H, W, C, h, w) -> (C, h, w)>
+  // #map2 = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)>
   bool returnVal =
       (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
        matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
@@ -483,12 +483,12 @@ static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
 
   SmallVector<int64_t> tempDilations(3, 1);
   SmallVector<int64_t> tempStrides(3, 1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
-  //                  -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
-  //                  -> (d5, d6, d7, d8, d4)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
-  //                  -> (d0, d1, d2, d3, d8, d4)>
+  // #map  = affine_map<(N, D, H, W, CM, d, h, w, C)
+  //                    -> (N, D + d, H + h, W + w, C)>
+  // #map1 = affine_map<(N, D, H, W, CM, d, h, w, C)
+  //                    -> (d, h, w, C, CM)>
+  // #map2 = affine_map<(N, D, H, W, CM, d, h, w, C)
+  //                    -> (N, D, H, W, C, CM)>
   bool returnVal =
       (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
        matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
@@ -526,9 +526,9 @@ static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
 
   SmallVector<int64_t> tempDilations(2, 1);
   SmallVector<int64_t> tempStrides(2, 1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  // #map  = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+  // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)>
+  // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
   bool returnVal =
       (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
        matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
@@ -562,9 +562,9 @@ static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
 
   SmallVector<int64_t> tempDilations(2, 1);
   SmallVector<int64_t> tempStrides(2, 1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  // #map  = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+  // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)>
+  // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
   bool returnVal =
       (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
        matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
@@ -598,9 +598,9 @@ static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
 
   SmallVector<int64_t> tempDilations(2, 1);
   SmallVector<int64_t> tempStrides(2, 1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  // #map  = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+  // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)>
+  // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
   bool returnVal =
       (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
        matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
@@ -635,9 +635,9 @@ static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op,
 
   SmallVector<int64_t> tempDilations(2, 1);
   SmallVector<int64_t> tempStrides(2, 1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  // #map  = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+  // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)>
+  // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
   bool returnVal =
       (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
        matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
@@ -672,9 +672,9 @@ static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op,
 
   SmallVector<int64_t> tempDilations(2, 1);
   SmallVector<int64_t> tempStrides(2, 1);
-  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
-  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
-  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+  // #map  = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+  // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)>
+  // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
   bool returnVal =
       (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
        matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
@@ -689,58 +689,61 @@ static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op,
                                                     tempDilations, tempStrides);
 }
 
-template <typename ConvOpTy>
-bool isaConvolutionOpOfType(LinalgOp op, SmallVector<int64_t> *dilations,
-                            SmallVector<int64_t> *strides) {
-  if constexpr (std::is_same_v<ConvOpTy, linalg::DepthwiseConv1DNwcWcOp>) {
-    return isaDepthwiseConv1DNwcWcOp(op, dilations, strides);
-  } else if constexpr (std::is_same_v<ConvOpTy,
-                                      linalg::DepthwiseConv2DNchwChwOp>) {
-    return isaDepthwiseConv2DNchwChwOp(op, dilations, strides);
-  } else if constexpr (std::is_same_v<ConvOpTy,
-                                      linalg::DepthwiseConv3DNdhwcDhwcmOp>) {
-    return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides);
-  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcMaxOp>) {
-    return isaPoolingNhwcMaxOp(op, dilations, strides);
-  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcMinOp>) {
-    return isaPoolingNhwcMinOp(op, dilations, strides);
-  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcSumOp>) {
-    return isaPoolingNhwcSumOp(op, dilations, strides);
-  } else if constexpr (std::is_same_v<ConvOpTy,
-                                      linalg::PoolingNhwcMaxUnsignedOp>) {
-    return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides);
-  } else if constexpr (std::is_same_v<ConvOpTy,
-                                      linalg::PoolingNhwcMinUnsignedOp>) {
-    return isaPoolingNhwcMinUnsignedOp(op, dilations, strides);
-  } else {
-    return false;
-  }
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  return isaDepthwiseConv1DNwcWcOp(op, dilations, strides);
 }
 
-template bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
-    LinalgOp op, SmallVector<int64_t> *dilations,
-    SmallVector<int64_t> *strides);
-template bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
-    SmallVector<int64_t> *strides);
-template bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+    SmallVector<int64_t> *strides) {
+  return isaDepthwiseConv2DNchwChwOp(op, dilations, strides);
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
-    SmallVector<int64_t> *strides);
-template bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
+    SmallVector<int64_t> *strides) {
+  return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides);
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
-    SmallVector<int64_t> *strides);
-template bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
+    SmallVector<int64_t> *strides) {
+  return isaPoolingNhwcMaxOp(op, dilations, strides);
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
-    SmallVector<int64_t> *strides);
-template bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
+    SmallVector<int64_t> *strides) {
+  return isaPoolingNhwcMinOp(op, dilations, strides);
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
-    SmallVector<int64_t> *strides);
-template bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+    SmallVector<int64_t> *strides) {
+  return isaPoolingNhwcSumOp(op, dilations, strides);
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
-    SmallVector<int64_t> *strides);
-template bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+    SmallVector<int64_t> *strides) {
+  return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides);
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
-    SmallVector<int64_t> *strides);
+    SmallVector<int64_t> *strides) {
+  return isaPoolingNhwcMinUnsignedOp(op, dilations, strides);
+}
 
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                             Value source, Value pad, bool nofold,
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
index 5a18ca8519be3..06c9a84049d81 100644
--- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -99,14 +99,26 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor<?x?x?x?xi32>, %filter: tenso
 
 // -----
 
-func.func @pooling_nhwc_min_unsigned(%input: tensor<?x?x?x?xi32>, %filter: tensor<?x?xi32>, %init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+func.func @pooling_nhwc_min_unsigned_integer(%input: tensor<?x?x?x?xi32>, %filter: tensor<?x?xi32>, %init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
   %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x?x?x?xi32>, tensor<?x?xi32>)
     outs (%init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
   return %0 : tensor<?x?x?x?xi32>
 }
-//      CHECK: @pooling_nhwc_min_unsigned
+//      CHECK: @pooling_nhwc_min_unsigned_integer
 //      CHECK:   linalg.pooling_nhwc_min_unsigned
 // CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
 //  CHECK-NOT:   linalg.generic
+
+func.func @pooling_nhwc_min_unsigned_float(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @pooling_nhwc_min_unsigned_float
+//      CHECK:   linalg.pooling_nhwc_min
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic

>From 0e9946b47c518867dae394c5221fba7d812c4803 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 22 Oct 2025 03:28:15 -0500
Subject: [PATCH 3/5] Review comment Hanhan v1.0

---
 .../Dialect/Linalg/Transforms/Specialize.cpp  |  40 -------
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 109 +++++-------------
 2 files changed, 32 insertions(+), 117 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 35861002e309e..2bfa21d9062ee 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -264,14 +264,6 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
   return namedOp;
 }
 
-/// TODO(avarma): Convolution ops whose iterator-types array has rank 2 will
-/// be added here incrementally in follow-up PRs.
-static FailureOr<LinalgOp>
-inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter,
-                                                GenericOp genericOp) {
-  return failure();
-}
-
 static FailureOr<LinalgOp>
 inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
                                                 GenericOp genericOp) {
@@ -283,14 +275,6 @@ inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
   return failure();
 }
 
-/// TODO(avarma): Convolution ops whose iterator-types array has rank 5 will
-/// be added here incrementally in follow-up PRs.
-static FailureOr<LinalgOp>
-inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter,
-                                                GenericOp genericOp) {
-  return failure();
-}
-
 static FailureOr<LinalgOp>
 inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
                                                 GenericOp genericOp) {
@@ -322,22 +306,6 @@ inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
   return failure();
 }
 
-/// TODO(avarma): Convolution ops whose iterator-types array has rank 7 will
-/// be added here incrementally in follow-up PRs.
-static FailureOr<LinalgOp>
-inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter,
-                                                GenericOp genericOp) {
-  return failure();
-}
-
-/// TODO(avarma): Convolution ops whose iterator-types array has rank 8 will
-/// be added here incrementally in follow-up PRs.
-static FailureOr<LinalgOp>
-inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter,
-                                                GenericOp genericOp) {
-  return failure();
-}
-
 static FailureOr<LinalgOp>
 inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
                                                 GenericOp genericOp) {
@@ -358,18 +326,10 @@ inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
       genericOp.getIteratorTypesArray();
   unsigned totalIterators = iteratorTypes.size();
   switch (totalIterators) {
-  case 2:
-    return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
   case 4:
     return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
-  case 5:
-    return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
   case 6:
     return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
-  case 7:
-    return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
-  case 8:
-    return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
   case 9:
     return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
   }
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 4dfec7b361eab..23c7fb68a5534 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -401,9 +401,10 @@ static bool updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
   return true;
 }
 
-static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
-                                      SmallVector<int64_t> *dilations,
-                                      SmallVector<int64_t> *strides) {
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
   if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
     return true;
 
@@ -432,9 +433,10 @@ static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
                                                     tempDilations, tempStrides);
 }
 
-static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
-                                        SmallVector<int64_t> *dilations,
-                                        SmallVector<int64_t> *strides) {
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
   if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
     return true;
 
@@ -466,9 +468,10 @@ static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
                                                     tempDilations, tempStrides);
 }
 
-static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
-                                           SmallVector<int64_t> *dilations,
-                                           SmallVector<int64_t> *strides) {
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
   if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
     return true;
 
@@ -507,8 +510,10 @@ static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
                                                     tempDilations, tempStrides);
 }
 
-static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                                SmallVector<int64_t> *strides) {
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNhwcMaxOp>(op))
     return true;
 
@@ -543,8 +548,10 @@ static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                                SmallVector<int64_t> *strides) {
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNhwcMinOp>(op))
     return true;
 
@@ -579,8 +586,10 @@ static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
-                                SmallVector<int64_t> *strides) {
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNhwcSumOp>(op))
     return true;
 
@@ -615,9 +624,10 @@ static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
                                                     tempDilations, tempStrides);
 }
 
-static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op,
-                                        SmallVector<int64_t> *dilations,
-                                        SmallVector<int64_t> *strides) {
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
     return true;
 
@@ -652,9 +662,10 @@ static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op,
                                                     tempDilations, tempStrides);
 }
 
-static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op,
-                                        SmallVector<int64_t> *dilations,
-                                        SmallVector<int64_t> *strides) {
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
   if (isa<linalg::PoolingNhwcMinUnsignedOp>(op))
     return true;
 
@@ -689,62 +700,6 @@ static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op,
                                                     tempDilations, tempStrides);
 }
 
-template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
-    LinalgOp op, SmallVector<int64_t> *dilations,
-    SmallVector<int64_t> *strides) {
-  return isaDepthwiseConv1DNwcWcOp(op, dilations, strides);
-}
-
-template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
-    LinalgOp op, SmallVector<int64_t> *dilations,
-    SmallVector<int64_t> *strides) {
-  return isaDepthwiseConv2DNchwChwOp(op, dilations, strides);
-}
-
-template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
-    LinalgOp op, SmallVector<int64_t> *dilations,
-    SmallVector<int64_t> *strides) {
-  return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides);
-}
-
-template <>
-bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
-    LinalgOp op, SmallVector<int64_t> *dilations,
-    SmallVector<int64_t> *strides) {
-  return isaPoolingNhwcMaxOp(op, dilations, strides);
-}
-
-template <>
-bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
-    LinalgOp op, SmallVector<int64_t> *dilations,
-    SmallVector<int64_t> *strides) {
-  return isaPoolingNhwcMinOp(op, dilations, strides);
-}
-
-template <>
-bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
-    LinalgOp op, SmallVector<int64_t> *dilations,
-    SmallVector<int64_t> *strides) {
-  return isaPoolingNhwcSumOp(op, dilations, strides);
-}
-
-template <>
-bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
-    LinalgOp op, SmallVector<int64_t> *dilations,
-    SmallVector<int64_t> *strides) {
-  return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides);
-}
-
-template <>
-bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
-    LinalgOp op, SmallVector<int64_t> *dilations,
-    SmallVector<int64_t> *strides) {
-  return isaPoolingNhwcMinUnsignedOp(op, dilations, strides);
-}
-
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                             Value source, Value pad, bool nofold,
                             ValueRange typeDynDims) {

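A note on how the new entry point is meant to be used: with this patch the
file-local helpers above are folded directly into the
`isaConvolutionOpOfType<ConvOpTy>` template specializations. A minimal caller
sketch follows; the helper `inspectConv` is hypothetical and not part of the
patch:

  #include "mlir/Dialect/Linalg/IR/Linalg.h"
  #include "mlir/Dialect/Linalg/Utils/Utils.h"
  #include "llvm/Support/raw_ostream.h"

  using namespace mlir;

  // Report whether `op` is a depthwise 1-D NWC/WC convolution and, on a
  // match, print the recovered stride and dilation (one entry each, for
  // the single spatial dimension).
  static void inspectConv(linalg::LinalgOp op) {
    // The matchers expect ops implementing ConvolutionOpInterface (a later
    // revision in this series asserts this), so guard the query first.
    if (!linalg::isaConvolutionOpInterface(op))
      return;
    SmallVector<int64_t> dilations, strides;
    if (linalg::isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
            op, &dilations, &strides))
      llvm::errs() << "conv1d_nwc_wc: stride=" << strides[0]
                   << " dilation=" << dilations[0] << "\n";
  }

The same call shape applies to every specialization in this patch; only the
template argument changes.
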
>From d44cc34ce67daccce72d930f6fea0982ce02a273 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 23 Oct 2025 06:04:50 -0500
Subject: [PATCH 4/5] Address Andrzej's review comments, v2.0

---
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 54 ++++---------------
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       |  6 +++
 2 files changed, 17 insertions(+), 43 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 2bfa21d9062ee..ce3df6a485f92 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -264,25 +264,26 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
   return namedOp;
 }
 
+// Converts linalg.generic to named linalg.*conv/pooling* where possible. To
+// improve the search speed, the convolution ops have been segregated based on
+// the rank of the iterator types array.
 static FailureOr<LinalgOp>
-inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
-                                                GenericOp genericOp) {
+inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
   SmallVector<int64_t> dilations, strides;
+  // Depthwise Convolution ops.
   if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
           genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
         rewriter, genericOp, dilations, strides);
-  return failure();
-}
-
-static FailureOr<LinalgOp>
-inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
-                                                GenericOp genericOp) {
-  SmallVector<int64_t> dilations, strides;
   if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
           genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
         rewriter, genericOp, dilations, strides);
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+        rewriter, genericOp, dilations, strides);
+  // Pooling ops.
   if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
                                                        &strides))
     return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
@@ -306,36 +307,6 @@ inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
   return failure();
 }
 
-static FailureOr<LinalgOp>
-inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
-                                                GenericOp genericOp) {
-  SmallVector<int64_t> dilations, strides;
-  if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
-          genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
-        rewriter, genericOp, dilations, strides);
-  return failure();
-}
-
-// Converts linalg.generic to named linalg.*conv/pooling* where possible. To
-// improve the search speed, the convolution ops have been segregated based on
-// the rank of iterator types array.
-static FailureOr<LinalgOp>
-inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
-  SmallVector<utils::IteratorType> iteratorTypes =
-      genericOp.getIteratorTypesArray();
-  unsigned totalIterators = iteratorTypes.size();
-  switch (totalIterators) {
-  case 4:
-    return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
-  case 6:
-    return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
-  case 9:
-    return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
-  }
-  return failure();
-}
-
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -417,10 +388,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
   }
 
   // Convolution - e.g. *conv/pooling*
-  if (isaConvolutionOpInterface(genericOp)) {
-    return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
-  }
-  return failure();
+  return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
 }
 
 namespace {
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 23c7fb68a5534..cd518fc38819e 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -263,6 +263,9 @@ static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
                                                                   body);
 }
 
+// max_unsigned ops should not allow float data types.
+// TODO: Retire OpDSL logic. Refer to:
+// https://github.com/llvm/llvm-project/pull/163724#discussion_r2438940337
 static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
   return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
                                                                   body);
@@ -273,6 +276,9 @@ static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
                                                                   body);
 }
 
+// min_unsigned ops should not allow float data types.
+// TODO: Retire OpDSL logic. Refer to:
+// https://github.com/llvm/llvm-project/pull/163724#discussion_r2438940337
 static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
   return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
                                                                   body);

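On the two TODOs added above: `bodyMatcherForMaxUnsignedPoolOps` and
`bodyMatcherForMinUnsignedPoolOps` are instantiated with both a float op
(`arith.maximumf`/`arith.minimumf`) and an unsigned-integer op
(`arith.maxui`/`arith.minui`), a carry-over from the OpDSL definitions, even
though an unsigned pooling op has no meaningful float form. One possible
tightening, sketched under the assumption that only the yielded combinator
needs checking (hypothetical helper, not part of the patch):

  #include "mlir/Dialect/Arith/IR/Arith.h"
  #include "mlir/IR/Value.h"

  // Accept only arith.maxui as the combinator of an unsigned max-pooling
  // body; an arith.maximumf combinator indicates a float pooling op.
  static bool isMaxUnsignedPoolCombinator(mlir::Value yieldVal) {
    return static_cast<bool>(yieldVal.getDefiningOp<mlir::arith::MaxUIOp>());
  }
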
>From 7b47d9e56db22366604e8608d099002cba5e9fd6 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Fri, 24 Oct 2025 02:58:28 -0500
Subject: [PATCH 5/5] Address Andrzej's review comments, v3.0

---
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 13 +++++--
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 39 ++++++++++---------
 2 files changed, 31 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index ce3df6a485f92..c68f7bd88c1ae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -267,10 +267,12 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
 // Converts linalg.generic to named linalg.*conv/pooling* where possible. To
 // improve the search speed, the convolution ops have been segregated based on
 // the rank of the iterator types array.
-static FailureOr<LinalgOp>
-inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
+static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
+                                                        GenericOp genericOp) {
   SmallVector<int64_t> dilations, strides;
+  // -----------------------------
   // Depthwise Convolution ops.
+  // -----------------------------
   if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
           genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
@@ -283,7 +285,9 @@ inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
           genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
         rewriter, genericOp, dilations, strides);
+  // -----------------------------
   // Pooling ops.
+  // -----------------------------
   if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
                                                        &strides))
     return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
@@ -388,7 +392,10 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
   }
 
   // Convolution - e.g. *conv/pooling*
-  return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
+  if (isaConvolutionOpInterface(genericOp)) {
+    return specializeLinalgConvolutions(rewriter, genericOp);
+  }
+  return failure();
 }
 
 namespace {
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index cd518fc38819e..c5c9e4b2f8387 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -265,7 +265,7 @@ static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
 
 // max_unsigned ops should not allow float data types.
 // TODO: Retire OpDSL logic. Refer to:
-// https://github.com/llvm/llvm-project/pull/163724#discussion_r2438940337
+//       https://github.com/llvm/llvm-project/issues/164800
 static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
   return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
                                                                   body);
@@ -278,7 +278,7 @@ static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
 
 // min_unsigned ops should not allow float data types.
 // TODO: Retire OpDSL logic. Refer to:
-// https://github.com/llvm/llvm-project/pull/163724#discussion_r2438940337
+//       https://github.com/llvm/llvm-project/issues/164800
 static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
   return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
                                                                   body);
@@ -407,6 +407,9 @@ static bool updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
   return true;
 }
 
+// ---------------------------------------------
+// Matchers for specific convolution operations.
+// ---------------------------------------------
 template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
@@ -414,8 +417,8 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
   if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
     return true;
 
-  if (!isaConvolutionOpInterface(op))
-    return false;
+  assert(isaConvolutionOpInterface(op) &&
+         "expected linalgOp to implement ConvolutionOpInterface");
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
   if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3}))
@@ -446,8 +449,8 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
   if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
     return true;
 
-  if (!isaConvolutionOpInterface(op))
-    return false;
+  assert(isaConvolutionOpInterface(op) &&
+         "expected linalgOp to implement ConvolutionOpInterface");
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4}))
@@ -481,8 +484,8 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
   if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
     return true;
 
-  if (!isaConvolutionOpInterface(op))
-    return false;
+  assert(isaConvolutionOpInterface(op) &&
+         "expected linalgOp to implement ConvolutionOpInterface");
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
   if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6}))
@@ -523,8 +526,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
   if (isa<linalg::PoolingNhwcMaxOp>(op))
     return true;
 
-  if (!isaConvolutionOpInterface(op))
-    return false;
+  assert(isaConvolutionOpInterface(op) &&
+         "expected linalgOp to implement ConvolutionOpInterface");
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
@@ -561,8 +564,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
   if (isa<linalg::PoolingNhwcMinOp>(op))
     return true;
 
-  if (!isaConvolutionOpInterface(op))
-    return false;
+  assert(isaConvolutionOpInterface(op) &&
+         "expected linalgOp to implement ConvolutionOpInterface");
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
@@ -599,8 +602,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
   if (isa<linalg::PoolingNhwcSumOp>(op))
     return true;
 
-  if (!isaConvolutionOpInterface(op))
-    return false;
+  assert(isaConvolutionOpInterface(op) &&
+         "expected linalgOp to implement ConvolutionOpInterface");
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
@@ -637,8 +640,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
   if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
     return true;
 
-  if (!isaConvolutionOpInterface(op))
-    return false;
+  assert(isaConvolutionOpInterface(op) &&
+         "expected linalgOp to implement ConvolutionOpInterface");
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
@@ -675,8 +678,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
   if (isa<linalg::PoolingNhwcMinUnsignedOp>(op))
     return true;
 
-  if (!isaConvolutionOpInterface(op))
-    return false;
+  assert(isaConvolutionOpInterface(op) &&
+         "expected linalgOp to implement ConvolutionOpInterface");
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))


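With this final revision the `isaConvolutionOpInterface` guard lives in
`specializeGenericOp` while each matcher asserts it, so any out-of-pass caller
must establish the interface before querying a matcher. A caller-side sketch
(the wrapper `matchPoolingNhwcMax` is hypothetical, not part of the patch):

  #include "mlir/Dialect/Linalg/IR/Linalg.h"
  #include "mlir/Dialect/Linalg/Utils/Utils.h"

  // The matchers now assert ConvolutionOpInterface rather than returning
  // false, so guard before dispatching to them.
  static bool matchPoolingNhwcMax(mlir::linalg::LinalgOp op,
                                  llvm::SmallVector<int64_t> &dilations,
                                  llvm::SmallVector<int64_t> &strides) {
    if (!mlir::linalg::isaConvolutionOpInterface(op))
      return false;
    return mlir::linalg::isaConvolutionOpOfType<
        mlir::linalg::PoolingNhwcMaxOp>(op, &dilations, &strides);
  }

End to end, the rewrite is exercised through the pass pipeline, e.g.
`mlir-opt --linalg-specialize-generic-ops <input.mlir>`, which turns a matching
`linalg.generic` back into the corresponding named conv/pooling op together
with its `strides` and `dilations` attributes.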
