[Mlir-commits] [mlir] [mlir][linalg] Use inferConvolutionDims for generic convolution downscaling (PR #180586)

Abhishek Varma llvmlistbot at llvm.org
Mon Feb 23 06:31:06 PST 2026


https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/180586

>From 364200bbf2363767f348c660861e179ccff72b14 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Mon, 9 Feb 2026 16:14:15 +0000
Subject: [PATCH 1/5] [mlir][linalg] Use inferConvolutionDims for generic
 convolution downscaling

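Refactor convolution downscaling to use inferConvolutionDims instead of
hardcoded dimension indices. This enables a single implementation to
handle all convolution-like ops (convolutions, depthwise convs, pooling)
regardless of layout, and works uniformly on both named ops and
equivalent linalg.generic ops. The rewritten op is always emitted as a
linalg.generic; run --linalg-specialize-generic-ops afterwards to
recover the corresponding named 1-D op where one exists.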

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
 .../Dialect/Linalg/Transforms/Transforms.h    |  56 +--
 .../TransformOps/LinalgTransformOps.cpp       |  32 +-
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 450 ++++++++----------
 .../Linalg/transform-op-decompose.mlir        |  76 +--
 .../transform-op-peel-and-vectorize-conv.mlir |   6 +-
 5 files changed, 249 insertions(+), 371 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 32067358438d3..544336fb64bb6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1640,54 +1640,22 @@ FailureOr<linalg::GenericOp> deduplicateOperandsAndRemoveDeadResults(
 // functional-stye API call.
 //===----------------------------------------------------------------------===//
 
-/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
-/// convolution ops. Works with both named ops and equivalent generic ops.
-template <typename Conv2DOp, typename Conv1DOp>
-struct DownscaleSizeOneWindowed2DConvolution final
+/// Rewrite 2-D convolution/pooling/depthwise ops whose H (or W) output and
+/// window sizes are both 1 into lower-dimensional linalg.generic ops.
+/// Handles both named ops and equivalent linalg.generic ops uniformly.
+FailureOr<linalg::GenericOp>
+downscaleSizeOneWindowedConvolution(RewriterBase &rewriter, LinalgOp op);
+
+/// Pattern wrapper around `downscaleSizeOneWindowedConvolution`.
+struct DownscaleSizeOneWindowedConvolution final
     : public OpInterfaceRewritePattern<LinalgOp> {
-  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
-
-  FailureOr<Conv1DOp> returningMatchAndRewrite(LinalgOp convOp,
-                                               PatternRewriter &rewriter) const;
-
-  LogicalResult matchAndRewrite(LinalgOp convOp,
-                                PatternRewriter &rewriter) const override {
-    return returningMatchAndRewrite(convOp, rewriter);
-  }
-};
-
-extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
-                                                             Conv1DNwcWcfOp>;
-extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
-                                                             Conv1DNcwFcwOp>;
-
-/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
-/// dimensions into 1-D depthwise convolution ops.
-struct DownscaleDepthwiseConv2DNhwcHwcOp final
-    : public OpInterfaceRewritePattern<LinalgOp> {
-  DownscaleDepthwiseConv2DNhwcHwcOp(MLIRContext *context,
-                                    PatternBenefit benefit = 1)
+  DownscaleSizeOneWindowedConvolution(MLIRContext *context,
+                                      PatternBenefit benefit = 1)
       : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}
 
-  FailureOr<DepthwiseConv1DNwcWcOp>
-  returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const;
-
-  LogicalResult matchAndRewrite(LinalgOp convOp,
-                                PatternRewriter &rewriter) const override {
-    return returningMatchAndRewrite(convOp, rewriter);
-  }
-};
-
-struct DownscaleConv2DOp final : public OpInterfaceRewritePattern<LinalgOp> {
-  DownscaleConv2DOp(MLIRContext *context, PatternBenefit benefit = 1)
-      : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}
-
-  FailureOr<Conv1DOp> returningMatchAndRewrite(LinalgOp convOp,
-                                               PatternRewriter &rewriter) const;
-
-  LogicalResult matchAndRewrite(LinalgOp convOp,
+  LogicalResult matchAndRewrite(LinalgOp op,
                                 PatternRewriter &rewriter) const override {
-    return returningMatchAndRewrite(convOp, rewriter);
+    return downscaleSizeOneWindowedConvolution(rewriter, op);
   }
 };
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 1d33c1e85376e..332f9f02c3e3c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -484,32 +484,12 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
                                    LinalgOp target,
                                    transform::ApplyToEachResultList &results,
                                    transform::TransformState &state) {
-#define DOWNSCALE(trans)                                                       \
-  {                                                                            \
-    FailureOr<LinalgOp> res = tryApply<trans>(target);                         \
-    if (succeeded(res)) {                                                      \
-      results.push_back(*res);                                                 \
-      return DiagnosedSilenceableFailure::success();                           \
-    }                                                                          \
-  }
-
-#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
-#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
-
-  DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
-  DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
-  DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
-  DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
-  DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
-  DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
-  DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
-  DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
-  DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
-  DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
-  DOWNSCALE(DownscaleConv2DOp)
-#undef DOWNSCALE_NORMAL
-#undef DOWNSCALE_CALL
-#undef DOWNSCALE
+  FailureOr<linalg::GenericOp> res =
+      downscaleSizeOneWindowedConvolution(rewriter, target);
+  if (succeeded(res)) {
+    results.push_back(*res);
+    return DiagnosedSilenceableFailure::success();
+  }
   return emitDefaultSilenceableFailure(target);
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index eb3eb48a7fe34..4e1d7e433cc5f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1422,289 +1422,215 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   return success();
 }
 
-// The following are patterns for downscaling convolution ops with size-1
-// window dimensions.
+//===----------------------------------------------------------------------===//
+// Generic DownscaleSizeOneWindowedConvolution
+//===----------------------------------------------------------------------===//
 //
-// Note that we'd eventually want to write such transformations in a generic
-// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
-// and then turning back to named ops. But for now it's fine to have a few
-// patterns matching special ops to get started.
-
-template <typename Conv2DOp, typename Conv1DOp>
-FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
-    returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const {
-  // Check if this LinalgOp is of the expected Conv2DOp type (named or generic).
-  std::optional<DilationsAndStrides> convParams =
-      matchConvolutionOpOfType<Conv2DOp>(convOp);
-  if (!convParams)
-    return failure();
-  SmallVector<int64_t> dilations = std::move(convParams->dilations);
-  SmallVector<int64_t> strides = std::move(convParams->strides);
-
-  if (convOp.hasPureBufferSemantics())
-    return failure(); // To be implemented.
-
-  Value input = convOp.getDpsInputs().front();
-  Value kernel = convOp.getDpsInputs().back();
-  Value output = convOp.getDpsInits().front();
-
-  auto inputType = dyn_cast<RankedTensorType>(input.getType());
-  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
-  auto outputType = dyn_cast<RankedTensorType>(output.getType());
-
-  auto kernelShape = kernelType.getShape();
-  auto outputShape = outputType.getShape();
-
-  // Get domain indices based on Conv2DOp type. These are known at compile time.
-  int64_t khIndex, kwIndex, ohIndex, owIndex;
-  if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNhwcHwcfOp> ||
-                std::is_same_v<Conv2DOp, linalg::PoolingNhwcSumOp> ||
-                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxOp> ||
-                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxUnsignedOp> ||
-                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinOp> ||
-                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinUnsignedOp>) {
-    // NHWC layout: kernel [H, W, ...], output [N, H, W, C]
-    khIndex = 0;
-    kwIndex = 1;
-    ohIndex = 1;
-    owIndex = 2;
-  } else if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNchwFchwOp>) {
-    // NCHW_FCHW layout: kernel [..., H, W], output [N, C, H, W]
-    khIndex = 2;
-    kwIndex = 3;
-    ohIndex = 2;
-    owIndex = 3;
-  } else if constexpr (std::is_same_v<Conv2DOp, linalg::PoolingNchwSumOp> ||
-                       std::is_same_v<Conv2DOp, linalg::PoolingNchwMaxOp>) {
-    // NCHW pooling layout: kernel [H, W], output [N, C, H, W]
-    khIndex = 0;
-    kwIndex = 1;
-    ohIndex = 2;
-    owIndex = 3;
+// This pattern rewrites 2-D convolution/pooling/depthwise ops with size-1
+// window dimensions into lower-dimensional ops. It uses inferConvolutionDims
+// to work with any layout and handles both named ops and equivalent
+// linalg.generic ops uniformly.
+//
+/// Returns the indices of affine map results that reference any of the given
+/// dimensions.
+static SmallVector<unsigned>
+getResultIndicesReferencingDims(AffineMap map, ArrayRef<unsigned> dims) {
+  SmallVector<unsigned> resultIndices;
+  for (unsigned dim : dims) {
+    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
+      AffineExpr expr = map.getResult(i);
+      if (expr.isFunctionOfDim(dim)) {
+        resultIndices.push_back(i);
+        break;
+      }
+    }
   }
+  return resultIndices;
+}
 
-  // Only handle the case where at least one of the window dimensions is
-  // of size 1. Other cases can rely on tiling to reduce to such cases.
-  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
-  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
-  bool removeH = (khSize == 1 && ohSize == 1);
-  bool removeW = (kwSize == 1 && owSize == 1);
-  if (!removeH && !removeW)
-    return failure();
+/// Helper to create a rank-reducing extract_slice that removes specific
+/// dimensions from a tensor.
+static Value createRankReducingExtractSlice(RewriterBase &rewriter,
+                                            Location loc, Value tensor,
+                                            ArrayRef<unsigned> dimsToRemove) {
+  auto tensorType = cast<RankedTensorType>(tensor.getType());
+  int64_t rank = tensorType.getRank();
+
+  // Compute new shape by removing the specified dimensions.
+  SmallVector<int64_t> newShape;
+  for (int64_t i = 0; i < rank; ++i) {
+    if (!llvm::is_contained(dimsToRemove, i))
+      newShape.push_back(tensorType.getDimSize(i));
+  }
 
-  // Get new shapes and types for all operands by removing the size-1
-  // dimension.
-  using RTTBuilder = RankedTensorType::Builder;
-  RankedTensorType newInputType =
-      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
-  RankedTensorType newKernelType =
-      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
-  RankedTensorType newOutputType =
-      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
-
-  // Rank-reduce operands.
-  Location loc = convOp.getLoc();
-  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, input, newInputType);
-  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, kernel, newKernelType);
-  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, output, newOutputType);
-
-  // Rank-reduce strides and dilations too.
-  // TODO: dropDim 1-liner helper.
-  strides.erase(strides.begin() + (removeH ? 0 : 1));
-  auto stridesAttr = rewriter.getI64VectorAttr(strides);
-
-  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
-  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
-
-  auto conv1DOp = Conv1DOp::create(
-      rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
-      ValueRange{newOutput}, stridesAttr, dilationsAttr);
-
-  // Insert back.
-  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
-      rewriter, loc, conv1DOp.getResult(0), output);
-  rewriter.replaceOp(convOp, inserted);
-
-  return conv1DOp;
+  auto newType = RankedTensorType::get(newShape, tensorType.getElementType());
+  return tensor::createCanonicalRankReducingExtractSliceOp(rewriter, loc,
+                                                           tensor, newType);
 }
 
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
-                                                              Conv1DNwcWcfOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
-                                                              Conv1DNcwFcwOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
-                                                              PoolingNwcSumOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
-                                                              PoolingNcwSumOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
-                                                              PoolingNwcMaxOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<
-    PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
-                                                              PoolingNwcMinOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<
-    PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
-                                                              PoolingNcwMaxOp>;
-
-FailureOr<DepthwiseConv1DNwcWcOp>
-DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
-    LinalgOp convOp, PatternRewriter &rewriter) const {
-  // Check if this LinalgOp is a DepthwiseConv2DNhwcHwcOp (named or generic).
-  std::optional<DilationsAndStrides> convParams =
-      matchConvolutionOpOfType<DepthwiseConv2DNhwcHwcOp>(convOp);
-  if (!convParams)
+/// Drops specified dimensions from an AffineExpr and compresses remaining
+/// dimension indices. Returns std::nullopt if the expression only references
+/// the dropped dimensions.
+static std::optional<AffineExpr>
+dropDimsAndCompress(AffineExpr expr, ArrayRef<unsigned> dimsToDrop,
+                    unsigned newNumDims, MLIRContext *ctx) {
+  // Check if expr only references dimensions to be dropped.
+  bool onlyReferencesDroppedDims = true;
+  for (unsigned d = 0; d < newNumDims + dimsToDrop.size(); ++d) {
+    if (expr.isFunctionOfDim(d) && !llvm::is_contained(dimsToDrop, d)) {
+      onlyReferencesDroppedDims = false;
+      break;
+    }
+  }
+  if (onlyReferencesDroppedDims && llvm::any_of(dimsToDrop, [&](unsigned d) {
+        return expr.isFunctionOfDim(d);
+      }))
+    return std::nullopt;
+
+  // Replace dimensions: compute new index for each old dimension.
+  // Dropped dimensions get mapped to constant 0, others get compressed.
+  SmallVector<AffineExpr> dimReplacements;
+  unsigned newDimIdx = 0;
+  for (unsigned d = 0; d < newNumDims + dimsToDrop.size(); ++d) {
+    if (llvm::is_contained(dimsToDrop, d)) {
+      dimReplacements.push_back(getAffineConstantExpr(0, ctx));
+    } else {
+      dimReplacements.push_back(getAffineDimExpr(newDimIdx++, ctx));
+    }
+  }
+
+  return expr.replaceDims(dimReplacements);
+}
+
+FailureOr<linalg::GenericOp>
+linalg::downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
+                                            LinalgOp op) {
+  auto maybeDims = inferConvolutionDims(op);
+  if (failed(maybeDims))
     return failure();
-  SmallVector<int64_t> dilations = std::move(convParams->dilations);
-  SmallVector<int64_t> strides = std::move(convParams->strides);
-
-  if (convOp.hasPureBufferSemantics())
-    return failure(); // To be implemented.
-
-  Value input = convOp.getDpsInputs().front();
-  Value kernel = convOp.getDpsInputs().back();
-  Value output = convOp.getDpsInits().front();
-
-  auto inputType = dyn_cast<RankedTensorType>(input.getType());
-  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
-  auto outputType = dyn_cast<RankedTensorType>(output.getType());
-
-  auto kernelShape = kernelType.getShape();
-  auto outputShape = outputType.getShape();
-
-  // Only handle the case where at least one of the window dimensions is
-  // of size 1. Other cases can rely on tiling to reduce to such cases.
-  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
-  int64_t ohSize = outputShape[1], owSize = outputShape[2];
-  bool removeH = (khSize == 1 && ohSize == 1);
-  bool removeW = (kwSize == 1 && owSize == 1);
-  if (!removeH && !removeW)
+
+  // Must be 2D Conv.
+  if (maybeDims->outputImage.size() != 2 || maybeDims->filterLoop.size() != 2)
     return failure();
 
-  // Get new shapes and types for all operands by removing the size-1
-  // dimension.
-  using RTTBuilder = RankedTensorType::Builder;
-  RankedTensorType newInputType =
-      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
-  RankedTensorType newKernelType =
-      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
-  RankedTensorType newOutputType =
-      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
-
-  // Rank-reduce operands.
-  Location loc = convOp.getLoc();
-  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, input, newInputType);
-  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, kernel, newKernelType);
-  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, output, newOutputType);
-
-  // Rank-reduce strides and dilations too.
-  // TODO: dropDim 1-liner helper.
-  strides.erase(strides.begin() + (removeH ? 0 : 1));
-  auto stridesAttr = rewriter.getI64VectorAttr(strides);
-
-  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
-  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
-
-  auto conv1DOp = DepthwiseConv1DNwcWcOp::create(
-      rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
-      ValueRange{newOutput}, stridesAttr, dilationsAttr);
-
-  // Insert back.
-  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
-      rewriter, loc, conv1DOp.getResult(0), output);
-  rewriter.replaceOp(convOp, inserted);
-
-  return conv1DOp;
-}
+  if (!op.hasPureTensorSemantics() || op.hasIndexSemantics())
+    return failure(); // Needs tensors; linalg.index is not remapped.
 
-FailureOr<Conv1DOp>
-DownscaleConv2DOp::returningMatchAndRewrite(LinalgOp convOp,
-                                            PatternRewriter &rewriter) const {
-  // Check if this LinalgOp is a Conv2DOp (named or generic).
-  std::optional<DilationsAndStrides> convParams =
-      matchConvolutionOpOfType<Conv2DOp>(convOp);
-  if (!convParams)
+  // 1. Get loop domain indices.
+  unsigned ohLoopIdx = maybeDims->outputImage[0];
+  unsigned owLoopIdx = maybeDims->outputImage[1];
+  unsigned khLoopIdx = maybeDims->filterLoop[0];
+  unsigned kwLoopIdx = maybeDims->filterLoop[1];
+
+  // 2. Get sizes from loop bounds.
+  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
+  int64_t ohSize = loopRanges[ohLoopIdx];
+  int64_t owSize = loopRanges[owLoopIdx];
+  int64_t khSize = loopRanges[khLoopIdx];
+  int64_t kwSize = loopRanges[kwLoopIdx];
+
+  // 3. Check if we can downscale.
+  bool canRemoveH = (khSize == 1 && ohSize == 1);
+  bool canRemoveW = (kwSize == 1 && owSize == 1);
+  if (!canRemoveH && !canRemoveW)
     return failure();
 
-  if (convOp.hasPureBufferSemantics())
-    return failure(); // To be implemented.
+  // Prefer removing H if both are possible.
+  bool removeH = canRemoveH;
 
-  Value input = convOp.getDpsInputs().front();
-  Value kernel = convOp.getDpsInputs().back();
-  Value output = convOp.getDpsInits().front();
+  // Determine which loop dims to remove (output spatial + corresponding filter)
+  SmallVector<unsigned> loopDimsToRemove;
+  if (removeH) {
+    loopDimsToRemove.push_back(ohLoopIdx);
+    loopDimsToRemove.push_back(khLoopIdx);
+  } else {
+    loopDimsToRemove.push_back(owLoopIdx);
+    loopDimsToRemove.push_back(kwLoopIdx);
+  }
+  // Canonical order only: the helpers below test membership with
+  // llvm::is_contained, so correctness does not depend on this sort.
+  llvm::sort(loopDimsToRemove);
 
-  auto inputType = dyn_cast<RankedTensorType>(input.getType());
-  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
-  auto outputType = dyn_cast<RankedTensorType>(output.getType());
+  // 4. Create new indexing maps with those loop dimensions dropped.
+  SmallVector<AffineMap> newMaps;
+  MLIRContext *ctx = op.getContext();
+  unsigned numDims = op.getNumLoops();
+  unsigned newNumDims = numDims - loopDimsToRemove.size();
+
+  for (AffineMap map : op.getIndexingMapsArray()) {
+    // Remove the loop dimensions from the map.
+    SmallVector<AffineExpr> newResults;
+    for (AffineExpr expr : map.getResults()) {
+      auto newExpr =
+          dropDimsAndCompress(expr, loopDimsToRemove, newNumDims, ctx);
+      if (newExpr)
+        newResults.push_back(*newExpr);
+    }
+    newMaps.push_back(AffineMap::get(newNumDims, 0, newResults, ctx));
+  }
 
-  auto kernelShape = kernelType.getShape();
-  auto outputShape = outputType.getShape();
+  // 5. Create new iterator types.
+  SmallVector<utils::IteratorType> newIterTypes;
+  auto iterTypes = op.getIteratorTypesArray();
+  for (unsigned idx = 0; idx < iterTypes.size(); ++idx) {
+    if (!llvm::is_contained(loopDimsToRemove, idx))
+      newIterTypes.push_back(iterTypes[idx]);
+  }
 
-  // Only handle the case where at least one of the window dimensions is
-  // of size 1. Other cases can rely on tiling to reduce to such cases.
-  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
-  int64_t ohSize = outputShape[0], owSize = outputShape[1];
-  bool removeH = (khSize == 1 && ohSize == 1);
-  bool removeW = (kwSize == 1 && owSize == 1);
-  if (!removeH && !removeW)
-    return failure();
+  // 6. Rank-reduce operands using extract_slice.
+  Location loc = op.getLoc();
+  SmallVector<Value> newInputs;
+  for (OpOperand *input : op.getDpsInputOperands()) {
+    AffineMap map = op.getMatchingIndexingMap(input);
+    SmallVector<unsigned> tensorDimsToRemove =
+        getResultIndicesReferencingDims(map, loopDimsToRemove);
+    Value reduced = createRankReducingExtractSlice(rewriter, loc, input->get(),
+                                                   tensorDimsToRemove);
+    newInputs.push_back(reduced);
+  }
+
+  SmallVector<Value> newOutputs;
+  Value originalOutput;
+  SmallVector<OpOperand *> initOperands =
+      llvm::to_vector(llvm::make_pointer_range(op.getDpsInitsMutable()));
+  for (OpOperand *output : initOperands) {
+    originalOutput = output->get();
+    AffineMap map = op.getMatchingIndexingMap(output);
+    SmallVector<unsigned> tensorDimsToRemove =
+        getResultIndicesReferencingDims(map, loopDimsToRemove);
+    Value reduced = createRankReducingExtractSlice(rewriter, loc, output->get(),
+                                                   tensorDimsToRemove);
+    newOutputs.push_back(reduced);
+  }
 
-  // Get new shapes and types for all operands by removing the size-1
-  // dimension.
-  using RTTBuilder = RankedTensorType::Builder;
-  RankedTensorType newInputType =
-      RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
-  RankedTensorType newKernelType =
-      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
-  RankedTensorType newOutputType =
-      RTTBuilder(outputType).dropDim(removeH ? 0 : 1);
-
-  // Rank-reduce operands.
-  Location loc = convOp.getLoc();
-  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, input, newInputType);
-  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, kernel, newKernelType);
-  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, output, newOutputType);
-
-  auto conv1DOp =
-      Conv1DOp::create(rewriter, loc, newOutputType,
-                       ValueRange{newInput, newKernel}, ValueRange{newOutput});
-
-  // Insert back.
-  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
-      rewriter, loc, conv1DOp.getResult(0), output);
-  rewriter.replaceOp(convOp, inserted);
-
-  return conv1DOp;
+  // 7. Create new linalg.generic with reduced dimensions
+  auto newOp = linalg::GenericOp::create(
+      rewriter, loc, TypeRange{newOutputs[0].getType()}, newInputs, newOutputs,
+      newMaps, newIterTypes,
+      [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
+        IRMapping mapping;
+        for (auto [oldArg, newArg] :
+             llvm::zip(op.getBlock()->getArguments(), args))
+          mapping.map(oldArg, newArg);
+        for (Operation &bodyOp : op.getBlock()->without_terminator())
+          b.clone(bodyOp, mapping);
+        auto yield = cast<linalg::YieldOp>(op.getBlock()->getTerminator());
+        linalg::YieldOp::create(b, nestedLoc,
+                                mapping.lookupOrDefault(yield.getOperand(0)));
+      });
+
+  // 8. Insert result back into original shape.
+  Value result = tensor::createCanonicalRankReducingInsertSliceOp(
+      rewriter, loc, newOp.getResult(0), originalOutput);
+
+  rewriter.replaceOp(op, result);
+  return newOp;
 }
 
 void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit) {
-  patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
-                                                     Conv1DNwcWcfOp>,
-               DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
-                                                     Conv1DNcwFcwOp>,
-               DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(
-      patterns.getContext(), benefit);
-  patterns.add<
-      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
-      DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
-      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
-      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
-                                            PoolingNwcMaxUnsignedOp>,
-      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
-      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
-                                            PoolingNwcMinUnsignedOp>,
-      DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
-      patterns.getContext(), benefit);
+  patterns.add<DownscaleSizeOneWindowedConvolution>(patterns.getContext(),
+                                                    benefit);
 }
 
 void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 6b03885069a37..e05ea6d815f76 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt --transform-interpreter --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --transform-interpreter --linalg-specialize-generic-ops --split-input-file %s | FileCheck %s
 // Test the same patterns on generic convolution ops by first generalizing the
 // named ops. This avoids duplicating lit tests for linalg.generic conv ops.
-// RUN: mlir-opt --linalg-generalize-named-ops --transform-interpreter --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --linalg-generalize-named-ops --transform-interpreter --linalg-specialize-generic-ops --split-input-file %s | FileCheck %s
 
 // CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -42,10 +42,11 @@ func.func @conv_2d_nchw_fchw(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x?x1x
   return %0 : tensor<?x?x1x?xf32>
 }
 
-// CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
+// Depthwise conv with height=1 (downscales height dimension)
+// CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc_height
 // CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>
-func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>) -> tensor<1x1x56x96xf32> {
+func.func @depthwise_conv_2d_nhwc_hwc_height(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>) -> tensor<1x1x56x96xf32> {
   // CHECK: %[[RES:.+]] = tensor.empty
   %init = tensor.empty() : tensor<1x1x56x96xf32>
   // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
@@ -62,6 +63,27 @@ func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: t
   return %0: tensor<1x1x56x96xf32>
 }
 
+// Depthwise conv with width=1 (downscales width dimension)
+// CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc_width
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x113x1x96xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<3x1x96xf32>
+func.func @depthwise_conv_2d_nhwc_hwc_width(%input: tensor<1x113x1x96xf32>, %filter: tensor<3x1x96xf32>) -> tensor<1x56x1x96xf32> {
+  // CHECK: %[[RES:.+]] = tensor.empty
+  %init = tensor.empty() : tensor<1x56x1x96xf32>
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICERES:.+]] = tensor.extract_slice %[[RES]]
+  // CHECK: %[[OPRES:.+]] = linalg.depthwise_conv_1d_nwc_wc
+  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
+  // CHECK-SAME: outs(%[[SLICERES]]
+  // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[OPRES]] into %[[RES]]
+  %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+         ins(%input, %filter: tensor<1x113x1x96xf32>, tensor<3x1x96xf32>)
+         outs(%init: tensor<1x56x1x96xf32>) -> tensor<1x56x1x96xf32>
+  // CHECK: %[[INSERTED]]
+  return %0: tensor<1x56x1x96xf32>
+}
+
 // CHECK-LABEL: @conv_2d
 // CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<1x?xf32>,
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
@@ -212,39 +234,19 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
 
 // CHECK-LABEL:      func.func @softmax(
 // CHECK-SAME:           %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
-// CHECK-DAG:        %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
-// CHECK-DAG:        %[[CST:.+]] = arith.constant 0xFFC00000 : f32
-// CHECK:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
-// CHECK:        %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME:     "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
-// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:          %[[D8:.+]] = arith.maxnumf %[[IN]], %[[OUT]] : f32
-// CHECK:          linalg.yield %[[D8]] : f32
-// CHECK:        } -> tensor<2x16xf32>
-// CHECK:        %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
-// CHECK-SAME:     ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[D3]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
-// CHECK-SAME:     outs(%[[DST]] : tensor<2x16x32xf32>) {
-// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:          %[[D8]] = arith.subf %[[IN]], %[[IN_1]] : f32
-// CHECK:          %[[D9:.+]] = math.exp %[[D8]] : f32
-// CHECK:          linalg.yield %[[D9]] : f32
-// CHECK:        } -> tensor<2x16x32xf32>
-// CHECK:        %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK:        %[[D5:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
-// CHECK:        %[[D6:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME:     "parallel", "reduction"]} ins(%[[D4]] : tensor<2x16x32xf32>) outs(%[[D5]] : tensor<2x16xf32>) {
-// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:          %[[D8]] = arith.addf %[[IN]], %[[OUT]] : f32
-// CHECK:          linalg.yield %[[D8]] : f32
-// CHECK:        } -> tensor<2x16xf32>
-// CHECK:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
-// CHECK-SAME:     ["parallel", "parallel", "parallel"]} ins(%[[D4]], %[[D6]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
-// CHECK-SAME:     outs(%[[DST]] : tensor<2x16x32xf32>) {
-// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:          %[[D8]] = arith.divf %[[IN]], %[[IN_1]] : f32
-// CHECK:          linalg.yield %[[D8]] : f32
-// CHECK:        } -> tensor<2x16x32xf32>
-// CHECK:        return %[[D7]] : tensor<2x16x32xf32>
+// CHECK:        linalg.fill
+// CHECK:        linalg.generic
+// CHECK:          arith.maxnumf
+// CHECK:        linalg.broadcast
+// CHECK:        linalg.generic
+// CHECK:          arith.subf
+// CHECK:          math.exp
+// CHECK:        linalg.fill
+// CHECK:        linalg.generic
+// CHECK:          arith.addf
+// CHECK:        linalg.broadcast
+// CHECK:        linalg.div
+// CHECK:        return
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
diff --git a/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir b/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir
index 4660cc75a1940..dfd502576d290 100644
--- a/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --transform-interpreter --split-input-file -resolve-shaped-type-result-dims -canonicalize | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter --linalg-specialize-generic-ops --split-input-file -resolve-shaped-type-result-dims -canonicalize | FileCheck %s
 
 // Demonstrates what happens when peeling the 4th loop (that corresponds to the
 // "depth" dimension in depthwise convs) followed by vectorization in the
@@ -73,7 +73,9 @@ module attributes {transform.with_named_sequence} {
 
     // 4. Apply loop peeling - only the 4th loop
     %main_loop, %remainder_loop = transform.loop.peel %loops_1#3 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">, !transform.op<"scf.for">)
-    %5 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %main_loop : (!transform.op<"scf.for">) -> !transform.any_op
+    // Match linalg.generic since decompose produces generic ops, and
+    // --linalg-specialize-generic-ops runs after the transform interpreter.
+    %5 = transform.structured.match ops{["linalg.generic"]} in %main_loop : (!transform.op<"scf.for">) -> !transform.any_op
 
     // 5. Vectorize, but only the main loop
     transform.structured.vectorize %5 vector_sizes [2, 4, [4], 16] : !transform.any_op

>From 03d3d8d35d56aaed33cf3e288a71366dff875c0b Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 10 Feb 2026 05:39:03 +0000
Subject: [PATCH 2/5] Batchless examples + specialization within

---
 .../Dialect/Linalg/Transforms/Transforms.h    |   9 +-
 .../Dialect/Linalg/Transforms/Transforms.cpp  |  28 ++--
 .../Linalg/transform-op-decompose.mlir        | 152 ++++++++++++++++--
 .../transform-op-peel-and-vectorize-conv.mlir |   6 +-
 4 files changed, 161 insertions(+), 34 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 544336fb64bb6..6f1f8be0d71d9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1641,10 +1641,11 @@ FailureOr<linalg::GenericOp> deduplicateOperandsAndRemoveDeadResults(
 //===----------------------------------------------------------------------===//
 
 /// Rewrite 2-D convolution/pooling/depthwise ops with size-1 window dimensions
-/// into lower-dimensional linalg.generic ops.
-/// Handles both named ops and equivalent linalg.generic ops uniformly.
-FailureOr<linalg::GenericOp>
-downscaleSizeOneWindowedConvolution(RewriterBase &rewriter, LinalgOp op);
+/// into lower-dimensional ops. Uses `inferConvolutionDims` to work with any
+/// layout and handles both named ops and equivalent linalg.generic ops
+/// uniformly. The result is specialized back to a named op when possible.
+FailureOr<LinalgOp> downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
+                                                        LinalgOp op);
 
 /// Pattern wrapper around `downscaleSizeOneWindowedConvolution`.
 struct DownscaleSizeOneWindowedConvolution final
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 4e1d7e433cc5f..cad1f75d4d344 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1502,7 +1502,7 @@ dropDimsAndCompress(AffineExpr expr, ArrayRef<unsigned> dimsToDrop,
   return expr.replaceDims(dimReplacements);
 }
 
-FailureOr<linalg::GenericOp>
+FailureOr<LinalgOp>
 linalg::downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
                                             LinalgOp op) {
   auto maybeDims = inferConvolutionDims(op);
@@ -1516,20 +1516,20 @@ linalg::downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
   if (op.hasPureBufferSemantics())
     return failure();
 
-  // 1. Get loop domain indices.
+  // Get loop domain indices.
   unsigned ohLoopIdx = maybeDims->outputImage[0];
   unsigned owLoopIdx = maybeDims->outputImage[1];
   unsigned khLoopIdx = maybeDims->filterLoop[0];
   unsigned kwLoopIdx = maybeDims->filterLoop[1];
 
-  // 2. Get sizes from loop bounds.
+  // Get sizes from loop bounds.
   SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
   int64_t ohSize = loopRanges[ohLoopIdx];
   int64_t owSize = loopRanges[owLoopIdx];
   int64_t khSize = loopRanges[khLoopIdx];
   int64_t kwSize = loopRanges[kwLoopIdx];
 
-  // 3. Check if we can downscale.
+  // Check if we can downscale.
   bool canRemoveH = (khSize == 1 && ohSize == 1);
   bool canRemoveW = (kwSize == 1 && owSize == 1);
   if (!canRemoveH && !canRemoveW)
@@ -1551,7 +1551,7 @@ linalg::downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
   // maps.
   llvm::sort(loopDimsToRemove);
 
-  // 4. Create new indexing maps with dimensions to be removed.
+  // Create new indexing maps with dimensions removed.
   SmallVector<AffineMap> newMaps;
   MLIRContext *ctx = op.getContext();
   unsigned numDims = op.getNumLoops();
@@ -1569,7 +1569,7 @@ linalg::downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
     newMaps.push_back(AffineMap::get(newNumDims, 0, newResults, ctx));
   }
 
-  // 5. Create new iterator types.
+  // Create new iterator types.
   SmallVector<utils::IteratorType> newIterTypes;
   auto iterTypes = op.getIteratorTypesArray();
   for (unsigned idx = 0; idx < iterTypes.size(); ++idx) {
@@ -1577,7 +1577,7 @@ linalg::downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
       newIterTypes.push_back(iterTypes[idx]);
   }
 
-  // 6. Rank-reduce operands using extract_slice.
+  // Rank-reduce operands using extract_slice.
   Location loc = op.getLoc();
   SmallVector<Value> newInputs;
   for (OpOperand *input : op.getDpsInputOperands()) {
@@ -1603,7 +1603,7 @@ linalg::downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
     newOutputs.push_back(reduced);
   }
 
-  // 7. Create new linalg.generic with reduced dimensions
+  // Create new linalg.generic with reduced dimensions.
   auto newOp = linalg::GenericOp::create(
       rewriter, loc, TypeRange{newOutputs[0].getType()}, newInputs, newOutputs,
       newMaps, newIterTypes,
@@ -1619,12 +1619,18 @@ linalg::downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
                                 mapping.lookup(yield.getOperand(0)));
       });
 
-  // 8. Insert result back into original shape.
+  // Try to specialize the generic back to a named op if possible.
+  LinalgOp resultOp = newOp;
+  FailureOr<LinalgOp> specializedOp = specializeGenericOp(rewriter, newOp);
+  if (succeeded(specializedOp))
+    resultOp = *specializedOp;
+
+  // Insert result back into original shape.
   Value result = tensor::createCanonicalRankReducingInsertSliceOp(
-      rewriter, loc, newOp.getResult(0), originalOutput);
+      rewriter, loc, resultOp->getResult(0), originalOutput);
 
   rewriter.replaceOp(op, result);
-  return newOp;
+  return resultOp;
 }
 
 void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index e05ea6d815f76..0f4a6f4c05d89 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -1,7 +1,20 @@
-// RUN: mlir-opt --transform-interpreter --linalg-specialize-generic-ops --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --transform-interpreter --split-input-file %s | FileCheck %s
 // Test the same patterns on generic convolution ops by first generalizing the
 // named ops. This avoids duplicating lit tests for linalg.generic conv ops.
-// RUN: mlir-opt --linalg-generalize-named-ops --transform-interpreter --linalg-specialize-generic-ops --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --linalg-generalize-named-ops --transform-interpreter --split-input-file %s | FileCheck %s
+
+// Expected indexing maps for batchless conv_1d_nwc_wcf.
+// CHECK-DAG:  #[[$CONV_I:.+]] = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d3)>
+// CHECK-DAG:  #[[$CONV_F:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
+// CHECK-DAG:  #[[$CONV_O:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+
+// Expected indexing maps for batchless depthwise_conv_1d_nwc_wc.
+// CHECK-DAG:  #[[$DW_I:.+]] = affine_map<(d0, d1, d2) -> (d0 + d2, d1)>
+// CHECK-DAG:  #[[$DW_F:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+
+// Expected indexing maps for batchless pooling_cw_min.
+// CHECK-DAG:  #[[$POOL_I:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
+// CHECK-DAG:  #[[$POOL_F:.+]] = affine_map<(d0, d1, d2) -> (d2)>
 
 // CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -227,6 +240,95 @@ func.func @pooling_nchw_max(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32
   return %0 : tensor<?x?x1x?xf32>
 }
 
+#map_conv_i = affine_map<(oh, ow, f, kh, kw, c) -> (oh + kh, ow + kw, c)>
+#map_conv_f = affine_map<(oh, ow, f, kh, kw, c) -> (kh, kw, c, f)>
+#map_conv_o = affine_map<(oh, ow, f, kh, kw, c) -> (oh, ow, f)>
+
+// CHECK-LABEL: @batchless_conv_2d_hwc_hwcf
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<1x14x8xf32>
+// CHECK-SAME:    %[[ARG1:.+]]: tensor<1x3x8x16xf32>
+// CHECK-SAME:    %[[ARG2:.+]]: tensor<1x12x16xf32>
+func.func @batchless_conv_2d_hwc_hwcf(%input: tensor<1x14x8xf32>, %filter: tensor<1x3x8x16xf32>, %output: tensor<1x12x16xf32>) -> tensor<1x12x16xf32> {
+  // CHECK:       %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK:       %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK:       %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK:       %[[SLICERES:.+]] = linalg.generic
+  // CHECK-SAME:    indexing_maps = [#[[$CONV_I]], #[[$CONV_F]], #[[$CONV_O]]]
+  // CHECK-SAME:    iterator_types = ["parallel", "parallel", "reduction", "reduction"]
+  // CHECK:       %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.generic {
+    indexing_maps = [#map_conv_i, #map_conv_f, #map_conv_o],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
+  } ins(%input, %filter : tensor<1x14x8xf32>, tensor<1x3x8x16xf32>)
+    outs(%output : tensor<1x12x16xf32>) {
+  ^bb0(%in: f32, %fil: f32, %out: f32):
+    %mul = arith.mulf %in, %fil : f32
+    %add = arith.addf %out, %mul : f32
+    linalg.yield %add : f32
+  } -> tensor<1x12x16xf32>
+  // CHECK:       return %[[RES]]
+  return %0 : tensor<1x12x16xf32>
+}
+
+#map_dw_i = affine_map<(oh, ow, c, kh, kw) -> (oh + kh, ow + kw, c)>
+#map_dw_f = affine_map<(oh, ow, c, kh, kw) -> (kh, kw, c)>
+#map_dw_o = affine_map<(oh, ow, c, kh, kw) -> (oh, ow, c)>
+
+// CHECK-LABEL: @batchless_depthwise_conv_2d_hwc_hwc
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<1x14x8xf32>
+// CHECK-SAME:    %[[ARG1:.+]]: tensor<1x3x8xf32>
+// CHECK-SAME:    %[[ARG2:.+]]: tensor<1x12x8xf32>
+func.func @batchless_depthwise_conv_2d_hwc_hwc(%input: tensor<1x14x8xf32>, %filter: tensor<1x3x8xf32>, %output: tensor<1x12x8xf32>) -> tensor<1x12x8xf32> {
+  // CHECK:       %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK:       %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK:       %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK:       %[[SLICERES:.+]] = linalg.generic
+  // CHECK-SAME:    indexing_maps = [#[[$DW_I]], #[[$DW_F]], #[[$MAP1]]]
+  // CHECK-SAME:    iterator_types = ["parallel", "parallel", "reduction"]
+  // CHECK:       %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.generic {
+    indexing_maps = [#map_dw_i, #map_dw_f, #map_dw_o],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
+  } ins(%input, %filter : tensor<1x14x8xf32>, tensor<1x3x8xf32>)
+    outs(%output : tensor<1x12x8xf32>) {
+  ^bb0(%in: f32, %fil: f32, %out: f32):
+    %mul = arith.mulf %in, %fil : f32
+    %add = arith.addf %out, %mul : f32
+    linalg.yield %add : f32
+  } -> tensor<1x12x8xf32>
+  // CHECK:       return %[[RES]]
+  return %0 : tensor<1x12x8xf32>
+}
+
+#map_pool_i = affine_map<(c, oh, ow, kh, kw) -> (c, oh + kh, ow + kw)>
+#map_pool_f = affine_map<(c, oh, ow, kh, kw) -> (kh, kw)>
+#map_pool_o = affine_map<(c, oh, ow, kh, kw) -> (c, oh, ow)>
+
+// CHECK-LABEL: @batchless_pooling_chw_min
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<8x1x14xf32>
+// CHECK-SAME:    %[[ARG1:.+]]: tensor<1x3xf32>
+// CHECK-SAME:    %[[ARG2:.+]]: tensor<8x1x12xf32>
+func.func @batchless_pooling_chw_min(%input: tensor<8x1x14xf32>, %filter: tensor<1x3xf32>, %output: tensor<8x1x12xf32>) -> tensor<8x1x12xf32> {
+  // CHECK:       %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK:       %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK:       %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK:       %[[SLICERES:.+]] = linalg.generic
+  // CHECK-SAME:    indexing_maps = [#[[$POOL_I]], #[[$POOL_F]], #[[$MAP1]]]
+  // CHECK-SAME:    iterator_types = ["parallel", "parallel", "reduction"]
+  // CHECK:       %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.generic {
+    indexing_maps = [#map_pool_i, #map_pool_f, #map_pool_o],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
+  } ins(%input, %filter : tensor<8x1x14xf32>, tensor<1x3xf32>)
+    outs(%output : tensor<8x1x12xf32>) {
+  ^bb0(%in: f32, %fil: f32, %out: f32):
+    %min = arith.minimumf %out, %in : f32
+    linalg.yield %min : f32
+  } -> tensor<8x1x12xf32>
+  // CHECK:       return %[[RES]]
+  return %0 : tensor<8x1x12xf32>
+}
+
 func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
   %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
   return %1 : tensor<2x16x32xf32>
@@ -234,19 +336,39 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
 
 // CHECK-LABEL:      func.func @softmax(
 // CHECK-SAME:           %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
-// CHECK:        linalg.fill
-// CHECK:        linalg.generic
-// CHECK:          arith.maxnumf
-// CHECK:        linalg.broadcast
-// CHECK:        linalg.generic
-// CHECK:          arith.subf
-// CHECK:          math.exp
-// CHECK:        linalg.fill
-// CHECK:        linalg.generic
-// CHECK:          arith.addf
-// CHECK:        linalg.broadcast
-// CHECK:        linalg.div
-// CHECK:        return
+// CHECK-DAG:        %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
+// CHECK-DAG:        %[[CST:.+]] = arith.constant 0xFFC00000 : f32
+// CHECK:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
+// CHECK:        %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:     "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:          %[[D8:.+]] = arith.maxnumf %[[IN]], %[[OUT]] : f32
+// CHECK:          linalg.yield %[[D8]] : f32
+// CHECK:        } -> tensor<2x16xf32>
+// CHECK:        %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
+// CHECK-SAME:     ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[D3]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
+// CHECK-SAME:     outs(%[[DST]] : tensor<2x16x32xf32>) {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:          %[[D8]] = arith.subf %[[IN]], %[[IN_1]] : f32
+// CHECK:          %[[D9:.+]] = math.exp %[[D8]] : f32
+// CHECK:          linalg.yield %[[D9]] : f32
+// CHECK:        } -> tensor<2x16x32xf32>
+// CHECK:        %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:        %[[D5:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
+// CHECK:        %[[D6:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:     "parallel", "reduction"]} ins(%[[D4]] : tensor<2x16x32xf32>) outs(%[[D5]] : tensor<2x16xf32>) {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:          %[[D8]] = arith.addf %[[IN]], %[[OUT]] : f32
+// CHECK:          linalg.yield %[[D8]] : f32
+// CHECK:        } -> tensor<2x16xf32>
+// CHECK:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
+// CHECK-SAME:     ["parallel", "parallel", "parallel"]} ins(%[[D4]], %[[D6]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
+// CHECK-SAME:     outs(%[[DST]] : tensor<2x16x32xf32>) {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:          %[[D8]] = arith.divf %[[IN]], %[[IN_1]] : f32
+// CHECK:          linalg.yield %[[D8]] : f32
+// CHECK:        } -> tensor<2x16x32xf32>
+// CHECK:        return %[[D7]] : tensor<2x16x32xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
diff --git a/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir b/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir
index dfd502576d290..4660cc75a1940 100644
--- a/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --transform-interpreter --linalg-specialize-generic-ops --split-input-file -resolve-shaped-type-result-dims -canonicalize | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter --split-input-file -resolve-shaped-type-result-dims -canonicalize | FileCheck %s
 
 // Demonstrates what happens when peeling the 4th loop (that corresponds to the
 // "depth" dimension in depthwise convs) followed by vectorization in the
@@ -73,9 +73,7 @@ module attributes {transform.with_named_sequence} {
 
     // 4. Apply loop peeling - only the 4th loop
     %main_loop, %remainder_loop = transform.loop.peel %loops_1#3 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">, !transform.op<"scf.for">)
-    // Match linalg.generic since decompose produces generic ops, and
-    // --linalg-specialize-generic-ops runs after the transform interpreter.
-    %5 = transform.structured.match ops{["linalg.generic"]} in %main_loop : (!transform.op<"scf.for">) -> !transform.any_op
+    %5 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %main_loop : (!transform.op<"scf.for">) -> !transform.any_op
 
     // 5. Vectorize, but only the main loop
     transform.structured.vectorize %5 vector_sizes [2, 4, [4], 16] : !transform.any_op

>From d045043b729d8f422aa2206e6122cf95a75b929b Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 11 Feb 2026 09:42:20 +0000
Subject: [PATCH 3/5] Move wrapper to pass + other review comments

---
 .../Dialect/Linalg/Transforms/Transforms.h    | 13 ----
 .../TransformOps/LinalgTransformOps.cpp       |  2 +-
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 64 +++++++++++--------
 .../Linalg/transform-op-decompose.mlir        | 28 ++++++++
 4 files changed, 66 insertions(+), 41 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6f1f8be0d71d9..9d93cf3488e0c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1647,19 +1647,6 @@ FailureOr<linalg::GenericOp> deduplicateOperandsAndRemoveDeadResults(
 FailureOr<LinalgOp> downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
                                                         LinalgOp op);
 
-/// Pattern wrapper around `downscaleSizeOneWindowedConvolution`.
-struct DownscaleSizeOneWindowedConvolution final
-    : public OpInterfaceRewritePattern<LinalgOp> {
-  DownscaleSizeOneWindowedConvolution(MLIRContext *context,
-                                      PatternBenefit benefit = 1)
-      : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}
-
-  LogicalResult matchAndRewrite(LinalgOp op,
-                                PatternRewriter &rewriter) const override {
-    return downscaleSizeOneWindowedConvolution(rewriter, op);
-  }
-};
-
 ///
 /// Linalg generalization pattern.
 ///
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 332f9f02c3e3c..12c52c6dabd91 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -484,7 +484,7 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
                                    LinalgOp target,
                                    transform::ApplyToEachResultList &results,
                                    transform::TransformState &state) {
-  FailureOr<linalg::GenericOp> res =
+  FailureOr<linalg::LinalgOp> res =
       downscaleSizeOneWindowedConvolution(rewriter, target);
   if (succeeded(res)) {
     results.push_back(*res);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index cad1f75d4d344..f5340e29ffce7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1426,11 +1426,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
 // Generic DownscaleSizeOneWindowedConvolution
 //===----------------------------------------------------------------------===//
 //
-// This pattern rewrites 2-D convolution/pooling/depthwise ops with size-1
-// window dimensions into lower-dimensional ops. It uses inferConvolutionDims
-// to work with any layout and handles both named ops and equivalent
-// linalg.generic ops uniformly.
-//
 /// Returns the indices of affine map results that reference any of the given
 /// dimensions.
 static SmallVector<unsigned>
@@ -1509,43 +1504,43 @@ linalg::downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
   if (failed(maybeDims))
     return failure();
 
-  // Must be 2D Conv.
+  // Requires exactly 2 spatial dimensions to downscale to 1D.
   if (maybeDims->outputImage.size() != 2 || maybeDims->filterLoop.size() != 2)
     return failure();
 
   if (op.hasPureBufferSemantics())
     return failure();
 
-  // Get loop domain indices.
-  unsigned ohLoopIdx = maybeDims->outputImage[0];
-  unsigned owLoopIdx = maybeDims->outputImage[1];
-  unsigned khLoopIdx = maybeDims->filterLoop[0];
-  unsigned kwLoopIdx = maybeDims->filterLoop[1];
+  // Get loop domain indices for spatial dimensions.
+  unsigned outSpatial0 = maybeDims->outputImage[0];
+  unsigned outSpatial1 = maybeDims->outputImage[1];
+  unsigned filterSpatial0 = maybeDims->filterLoop[0];
+  unsigned filterSpatial1 = maybeDims->filterLoop[1];
 
   // Get sizes from loop bounds.
   SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
-  int64_t ohSize = loopRanges[ohLoopIdx];
-  int64_t owSize = loopRanges[owLoopIdx];
-  int64_t khSize = loopRanges[khLoopIdx];
-  int64_t kwSize = loopRanges[kwLoopIdx];
-
-  // Check if we can downscale.
-  bool canRemoveH = (khSize == 1 && ohSize == 1);
-  bool canRemoveW = (kwSize == 1 && owSize == 1);
-  if (!canRemoveH && !canRemoveW)
+  int64_t outSize0 = loopRanges[outSpatial0];
+  int64_t outSize1 = loopRanges[outSpatial1];
+  int64_t filterSize0 = loopRanges[filterSpatial0];
+  int64_t filterSize1 = loopRanges[filterSpatial1];
+
+  // Check if we can downscale by removing a spatial dimension.
+  bool canRemoveSpatial0 = (filterSize0 == 1 && outSize0 == 1);
+  bool canRemoveSpatial1 = (filterSize1 == 1 && outSize1 == 1);
+  if (!canRemoveSpatial0 && !canRemoveSpatial1)
     return failure();
 
-  // Prefer removing H if both are possible.
-  bool removeH = canRemoveH;
+  // Prioritize dropping the leading spatial dimension if both are removable.
+  bool removeSpatial0 = canRemoveSpatial0;
 
   // Determine which loop dims to remove (output spatial + corresponding filter)
   SmallVector<unsigned> loopDimsToRemove;
-  if (removeH) {
-    loopDimsToRemove.push_back(ohLoopIdx);
-    loopDimsToRemove.push_back(khLoopIdx);
+  if (removeSpatial0) {
+    loopDimsToRemove.push_back(outSpatial0);
+    loopDimsToRemove.push_back(filterSpatial0);
   } else {
-    loopDimsToRemove.push_back(owLoopIdx);
-    loopDimsToRemove.push_back(kwLoopIdx);
+    loopDimsToRemove.push_back(outSpatial1);
+    loopDimsToRemove.push_back(filterSpatial1);
   }
   // Sort for correct index compression when removing dimensions from affine
   // maps.
@@ -1633,6 +1628,21 @@ linalg::downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
   return resultOp;
 }
 
+namespace {
+/// Pattern applying `downscaleSizeOneWindowedConvolution` to any LinalgOp.
+struct DownscaleSizeOneWindowedConvolution final
+    : public OpInterfaceRewritePattern<LinalgOp> {
+  DownscaleSizeOneWindowedConvolution(MLIRContext *context,
+                                      PatternBenefit benefit = 1)
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(LinalgOp op,
+                                PatternRewriter &rewriter) const override {
+    return linalg::downscaleSizeOneWindowedConvolution(rewriter, op);
+  }
+};
+} // namespace
+
 void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit) {
   patterns.add<DownscaleSizeOneWindowedConvolution>(patterns.getContext(),
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 0f4a6f4c05d89..8a82c838670d5 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -329,6 +329,34 @@ func.func @batchless_pooling_chw_min(%input: tensor<8x1x14xf32>, %filter: tensor
   return %0 : tensor<8x1x12xf32>
 }
 
+#map_cross_i = affine_map<(d0, d1, d2, d3) -> (d0 + d3, d1 + d2)>
+#map_cross_f = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map_cross_o = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+
+// CHECK-LABEL: @cross_conv_nonstandard_loop_order
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<1x15xf32>
+// CHECK-SAME:    %[[ARG1:.+]]: tensor<3x1xf32>
+// CHECK-SAME:    %[[ARG2:.+]]: tensor<1x12xf32>
+func.func @cross_conv_nonstandard_loop_order(%input: tensor<1x15xf32>, %filter: tensor<3x1xf32>, %output: tensor<1x12xf32>) -> tensor<1x12xf32> {
+  // CHECK:       %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK:       %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK:       %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK:       %[[SLICERES:.+]] = linalg.conv_1d
+  // CHECK:       %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK:       return %[[RES]]
+  %0 = linalg.generic {
+    indexing_maps = [#map_cross_i, #map_cross_f, #map_cross_o],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]
+  } ins(%input, %filter : tensor<1x15xf32>, tensor<3x1xf32>)
+    outs(%output : tensor<1x12xf32>) {
+  ^bb0(%in: f32, %fil: f32, %out: f32):
+    %mul = arith.mulf %in, %fil : f32
+    %add = arith.addf %out, %mul : f32
+    linalg.yield %add : f32
+  } -> tensor<1x12xf32>
+  return %0 : tensor<1x12xf32>
+}
+
 func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
   %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
   return %1 : tensor<2x16x32xf32>

>From 699fc9336adb8c050ed2d6d21a34d141c3bbe68b Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 12 Feb 2026 06:52:28 +0000
Subject: [PATCH 4/5] Few more nits

---
 .../Dialect/Linalg/TransformOps/LinalgTransformOps.td  | 10 +++++++---
 .../mlir/Dialect/Linalg/Transforms/Transforms.h        |  3 ++-
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp      |  2 +-
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 70d424bae9285..d4d9d56110f19 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -288,9 +288,13 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
      TransformEachOpTrait,
      ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
-    Decomposes named complex operations, such as higher-dimensional
-    (depthwise) convolutions, into combinations of lower-dimensional equivalents
-    when possible.
+    Decomposes higher-dimensional convolution ops into lower-dimensional
+    equivalents when possible. This operates on both named ops and equivalent
+    `linalg.generic` ops that have convolution-like structure (as determined
+    by `inferConvolutionDims`).
+
+    When the input is a named op, the result is specialized back to a named
+    op when possible; `linalg.generic` inputs produce `linalg.generic` ops.
 
     #### Return modes
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 9d93cf3488e0c..3f90992133787 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1640,10 +1640,11 @@ FailureOr<linalg::GenericOp> deduplicateOperandsAndRemoveDeadResults(
 // functional-stye API call.
 //===----------------------------------------------------------------------===//
 
-/// Rewrite 2-D convolution/pooling/depthwise ops with size-1 window dimensions
+/// Rewrite convolution/pooling/depthwise ops with size-1 window dimensions
 /// into lower-dimensional ops. Uses `inferConvolutionDims` to work with any
 /// layout and handles both named ops and equivalent linalg.generic ops
 /// uniformly. The result is specialized back to a named op when possible.
+/// TODO: Support n-D to (n-1)-D downscaling.
 FailureOr<LinalgOp> downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
                                                         LinalgOp op);
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index f5340e29ffce7..87d5fc5810e49 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1504,7 +1504,7 @@ linalg::downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
   if (failed(maybeDims))
     return failure();
 
-  // Requires exactly 2 spatial dimensions to downscale to 1D.
+  // Currently supports only 2D convolutions.
   if (maybeDims->outputImage.size() != 2 || maybeDims->filterLoop.size() != 2)
     return failure();
 

>From 3cfbc306fcd436e682d3ad1c1d8086a311d7d17e Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Mon, 23 Feb 2026 14:26:52 +0000
Subject: [PATCH 5/5] Guard specialization and update lit test minimally

---
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp    | 11 +++++++----
 mlir/test/Dialect/Linalg/transform-op-decompose.mlir | 12 ++++++++----
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 87d5fc5810e49..63e831adf99a7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1614,11 +1614,14 @@ linalg::downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
                                 mapping.lookup(yield.getOperand(0)));
       });
 
-  // Try to specialize the generic back to a named op if possible.
+  // Try to specialize the generic back to a named op only if the input was
+  // already a specialized (named) op.
   LinalgOp resultOp = newOp;
-  FailureOr<LinalgOp> specializedOp = specializeGenericOp(rewriter, newOp);
-  if (succeeded(specializedOp))
-    resultOp = *specializedOp;
+  if (!isa<GenericOp>(op)) {
+    FailureOr<LinalgOp> specializedOp = specializeGenericOp(rewriter, newOp);
+    if (succeeded(specializedOp))
+      resultOp = *specializedOp;
+  }
 
   // Insert result back into original shape.
   Value result = tensor::createCanonicalRankReducingInsertSliceOp(
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 8a82c838670d5..3897f8502bb04 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -1,7 +1,4 @@
 // RUN: mlir-opt --transform-interpreter --split-input-file %s | FileCheck %s
-// Test the same patterns on generic convolution ops by first generalizing the
-// named ops. This avoids duplicating lit tests for linalg.generic conv ops.
-// RUN: mlir-opt --linalg-generalize-named-ops --transform-interpreter --split-input-file %s | FileCheck %s
 
 // Expected indexing maps for batchless conv_1d_nwc_wcf.
 // CHECK-DAG:  #[[$CONV_I:.+]] = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d3)>
@@ -19,6 +16,11 @@
 // CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
+// Expected indexing maps for 1D conv (cross-conv after downscale from generic).
+// CHECK-DAG:  #[[$CROSS_1D_I:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-DAG:  #[[$CROSS_1D_F:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG:  #[[$CROSS_1D_O:.+]] = affine_map<(d0, d1) -> (d0)>
+
 // CHECK-LABEL: @conv_2d_nhwc_hwcf
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
@@ -341,7 +343,9 @@ func.func @cross_conv_nonstandard_loop_order(%input: tensor<1x15xf32>, %filter:
   // CHECK:       %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
   // CHECK:       %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
   // CHECK:       %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK:       %[[SLICERES:.+]] = linalg.conv_1d
+  // CHECK:       %[[SLICERES:.+]] = linalg.generic
+  // CHECK-SAME:    indexing_maps = [#[[$CROSS_1D_I]], #[[$CROSS_1D_F]], #[[$CROSS_1D_O]]]
+  // CHECK-SAME:    iterator_types = ["parallel", "reduction"]
   // CHECK:       %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
   // CHECK:       return %[[RES]]
   %0 = linalg.generic {



More information about the Mlir-commits mailing list