[Mlir-commits] [mlir] [Linalg] Update Conv Decomposition patterns to work with generic convolution ops as well (PR #174196)

Abhishek Varma llvmlistbot at llvm.org
Tue Jan 13 22:55:48 PST 2026


https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/174196

>From 7c1296fc254ad2773bacdfa66107880c603f76f1 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 1 Jan 2026 07:36:31 +0000
Subject: [PATCH 1/2] [Linalg] Update Conv Decomposition to work with generic
 conv ops

-- This commit updates Conv Decomposition to work with both named and
   generic convolution ops.
-- This required an update to the `isaConvolutionOpOfType` API to also
   populate dilations/strides info for named convolution ops. Also, since
   a generic LinalgOp is now being used as the root op in the pattern
   above, the assert that the op implements ConvolutionOpInterface has
   been replaced with an early-exit `if` check.

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
 .../Dialect/Linalg/Transforms/Transforms.h    |  27 +-
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 107 ++++----
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 188 +++++++-------
 .../Linalg/transform-op-decompose.mlir        | 238 ++++++++++++------
 4 files changed, 322 insertions(+), 238 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6678d693719bf..32067358438d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1641,16 +1641,16 @@ FailureOr<linalg::GenericOp> deduplicateOperandsAndRemoveDeadResults(
 //===----------------------------------------------------------------------===//
 
 /// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
-/// convolution ops.
+/// convolution ops. Works with both named ops and equivalent generic ops.
 template <typename Conv2DOp, typename Conv1DOp>
 struct DownscaleSizeOneWindowed2DConvolution final
-    : public OpRewritePattern<Conv2DOp> {
-  using OpRewritePattern<Conv2DOp>::OpRewritePattern;
+    : public OpInterfaceRewritePattern<LinalgOp> {
+  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
 
-  FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
+  FailureOr<Conv1DOp> returningMatchAndRewrite(LinalgOp convOp,
                                                PatternRewriter &rewriter) const;
 
-  LogicalResult matchAndRewrite(Conv2DOp convOp,
+  LogicalResult matchAndRewrite(LinalgOp convOp,
                                 PatternRewriter &rewriter) const override {
     return returningMatchAndRewrite(convOp, rewriter);
   }
@@ -1664,29 +1664,28 @@ extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
 /// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
 /// dimensions into 1-D depthwise convolution ops.
 struct DownscaleDepthwiseConv2DNhwcHwcOp final
-    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
+    : public OpInterfaceRewritePattern<LinalgOp> {
   DownscaleDepthwiseConv2DNhwcHwcOp(MLIRContext *context,
                                     PatternBenefit benefit = 1)
-      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit) {}
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}
 
   FailureOr<DepthwiseConv1DNwcWcOp>
-  returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
-                           PatternRewriter &rewriter) const;
+  returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const;
 
-  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
+  LogicalResult matchAndRewrite(LinalgOp convOp,
                                 PatternRewriter &rewriter) const override {
     return returningMatchAndRewrite(convOp, rewriter);
   }
 };
 
-struct DownscaleConv2DOp final : public OpRewritePattern<Conv2DOp> {
+struct DownscaleConv2DOp final : public OpInterfaceRewritePattern<LinalgOp> {
   DownscaleConv2DOp(MLIRContext *context, PatternBenefit benefit = 1)
-      : OpRewritePattern<Conv2DOp>(context, benefit) {}
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}
 
-  FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
+  FailureOr<Conv1DOp> returningMatchAndRewrite(LinalgOp convOp,
                                                PatternRewriter &rewriter) const;
 
-  LogicalResult matchAndRewrite(Conv2DOp convOp,
+  LogicalResult matchAndRewrite(LinalgOp convOp,
                                 PatternRewriter &rewriter) const override {
     return returningMatchAndRewrite(convOp, rewriter);
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 96cc378f6c21a..7972408318b95 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -32,6 +32,7 @@
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/InterleavedRange.h"
 #include "llvm/Support/raw_ostream.h"
+#include <type_traits>
 #include <utility>
 
 #define DEBUG_TYPE "linalg-transforms"
@@ -1406,13 +1407,18 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
 
 template <typename Conv2DOp, typename Conv1DOp>
 FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
-    returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
+    returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const {
+  // Check if this LinalgOp is of the expected Conv2DOp type (named or generic).
+  SmallVector<int64_t> dilations, strides;
+  if (!linalg::isaConvolutionOpOfType<Conv2DOp>(convOp, &dilations, &strides))
+    return failure();
+
   if (convOp.hasPureBufferSemantics())
     return failure(); // To be implemented.
 
-  Value input = convOp.getInputs().front();
-  Value kernel = convOp.getInputs().back();
-  Value output = convOp.getOutputs().front();
+  Value input = convOp.getDpsInputs().front();
+  Value kernel = convOp.getDpsInputs().back();
+  Value output = convOp.getDpsInits().front();
 
   auto inputType = dyn_cast<RankedTensorType>(input.getType());
   auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
@@ -1421,38 +1427,33 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
   auto kernelShape = kernelType.getShape();
   auto outputShape = outputType.getShape();
 
-  // Get domain indices based on conv2D layout.
-  auto [khIndex, kwIndex, ohIndex, owIndex] =
-      TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
-          convOp)
-          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
-            return std::make_tuple(0, 1, 1, 2);
-          })
-          .Case([&](linalg::Conv2DNchwFchwOp op) {
-            return std::make_tuple(2, 3, 2, 3);
-          })
-          .Case([&](linalg::PoolingNhwcSumOp op) {
-            return std::make_tuple(0, 1, 1, 2);
-          })
-          .Case([&](linalg::PoolingNchwSumOp op) {
-            return std::make_tuple(0, 1, 2, 3);
-          })
-          .Case([&](linalg::PoolingNhwcMaxOp op) {
-            return std::make_tuple(0, 1, 1, 2);
-          })
-          .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
-            return std::make_tuple(0, 1, 1, 2);
-          })
-          .Case([&](linalg::PoolingNhwcMinOp op) {
-            return std::make_tuple(0, 1, 1, 2);
-          })
-          .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
-            return std::make_tuple(0, 1, 1, 2);
-          })
-          .Case([&](linalg::PoolingNchwMaxOp op) {
-            return std::make_tuple(0, 1, 2, 3);
-          })
-          .DefaultUnreachable("unexpected conv2d/pool2d operation.");
+  // Get domain indices based on Conv2DOp type. These are known at compile time.
+  int64_t khIndex, kwIndex, ohIndex, owIndex;
+  if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNhwcHwcfOp> ||
+                std::is_same_v<Conv2DOp, linalg::PoolingNhwcSumOp> ||
+                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxOp> ||
+                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxUnsignedOp> ||
+                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinOp> ||
+                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinUnsignedOp>) {
+    // NHWC layout: kernel [H, W, ...], output [N, H, W, C]
+    khIndex = 0;
+    kwIndex = 1;
+    ohIndex = 1;
+    owIndex = 2;
+  } else if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNchwFchwOp>) {
+    // NCHW_FCHW layout: kernel [..., H, W], output [N, C, H, W]
+    khIndex = 2;
+    kwIndex = 3;
+    ohIndex = 2;
+    owIndex = 3;
+  } else if constexpr (std::is_same_v<Conv2DOp, linalg::PoolingNchwSumOp> ||
+                       std::is_same_v<Conv2DOp, linalg::PoolingNchwMaxOp>) {
+    // NCHW pooling layout: kernel [H, W], output [N, C, H, W]
+    khIndex = 0;
+    kwIndex = 1;
+    ohIndex = 2;
+    owIndex = 3;
+  }
 
   // Only handle the case where at least one of the window dimensions is
   // of size 1. Other cases can rely on tiling to reduce to such cases.
@@ -1484,13 +1485,9 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
 
   // Rank-reduce strides and dilations too.
   // TODO: dropDim 1-liner helper.
-  auto strides =
-      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
   strides.erase(strides.begin() + (removeH ? 0 : 1));
   auto stridesAttr = rewriter.getI64VectorAttr(strides);
 
-  auto dilations =
-      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
   dilations.erase(dilations.begin() + (removeH ? 0 : 1));
   auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
 
@@ -1527,13 +1524,19 @@ template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
 
 FailureOr<DepthwiseConv1DNwcWcOp>
 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
-    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
+    LinalgOp convOp, PatternRewriter &rewriter) const {
+  // Check if this LinalgOp is a DepthwiseConv2DNhwcHwcOp (named or generic).
+  SmallVector<int64_t> dilations, strides;
+  if (!linalg::isaConvolutionOpOfType<DepthwiseConv2DNhwcHwcOp>(
+          convOp, &dilations, &strides))
+    return failure();
+
   if (convOp.hasPureBufferSemantics())
     return failure(); // To be implemented.
 
-  Value input = convOp.getInputs().front();
-  Value kernel = convOp.getInputs().back();
-  Value output = convOp.getOutputs().front();
+  Value input = convOp.getDpsInputs().front();
+  Value kernel = convOp.getDpsInputs().back();
+  Value output = convOp.getDpsInits().front();
 
   auto inputType = dyn_cast<RankedTensorType>(input.getType());
   auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
@@ -1572,12 +1575,9 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
 
   // Rank-reduce strides and dilations too.
   // TODO: dropDim 1-liner helper.
-  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
   strides.erase(strides.begin() + (removeH ? 0 : 1));
   auto stridesAttr = rewriter.getI64VectorAttr(strides);
 
-  auto dilations =
-      llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
   dilations.erase(dilations.begin() + (removeH ? 0 : 1));
   auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
 
@@ -1594,14 +1594,19 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
 }
 
 FailureOr<Conv1DOp>
-DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
+DownscaleConv2DOp::returningMatchAndRewrite(LinalgOp convOp,
                                             PatternRewriter &rewriter) const {
+  // Check if this LinalgOp is a Conv2DOp (named or generic).
+  SmallVector<int64_t> dilations, strides;
+  if (!linalg::isaConvolutionOpOfType<Conv2DOp>(convOp, &dilations, &strides))
+    return failure();
+
   if (convOp.hasPureBufferSemantics())
     return failure(); // To be implemented.
 
-  Value input = convOp.getInputs().front();
-  Value kernel = convOp.getInputs().back();
-  Value output = convOp.getOutputs().front();
+  Value input = convOp.getDpsInputs().front();
+  Value kernel = convOp.getDpsInputs().back();
+  Value output = convOp.getDpsInits().front();
 
   auto inputType = dyn_cast<RankedTensorType>(input.getType());
   auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 4e1731ddcfd45..daf02442bb21a 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -602,8 +602,8 @@ matchConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
                        &result.strides);
@@ -630,8 +630,8 @@ matchConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
                        &result.strides);
@@ -661,8 +661,8 @@ matchConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
                        &result.strides);
@@ -692,8 +692,8 @@ matchConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides);
@@ -723,8 +723,8 @@ matchConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides);
@@ -758,8 +758,8 @@ matchConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides);
@@ -795,8 +795,8 @@ matchConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides);
@@ -830,8 +830,8 @@ matchConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides);
@@ -867,8 +867,8 @@ matchConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides);
@@ -902,8 +902,8 @@ matchConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides);
@@ -939,8 +939,8 @@ matchConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides);
@@ -975,8 +975,8 @@ matchConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides);
@@ -1011,8 +1011,8 @@ matchConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides);
@@ -1049,8 +1049,8 @@ matchConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides);
@@ -1085,8 +1085,8 @@ matchConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides);
@@ -1123,8 +1123,8 @@ matchConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
                        &result.strides);
@@ -1158,8 +1158,8 @@ matchConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
                        &result.strides);
@@ -1196,8 +1196,8 @@ matchConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfQOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
                        &result.strides);
@@ -1236,8 +1236,8 @@ matchConvolutionOpOfType<linalg::Conv3DNcdhwFcdhwOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
                        &result.strides);
@@ -1275,8 +1275,8 @@ matchConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
                        &result.strides);
@@ -1306,8 +1306,8 @@ matchConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
                        &result.strides);
@@ -1337,8 +1337,8 @@ matchConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
                        &result.strides);
@@ -1369,8 +1369,8 @@ matchConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides);
@@ -1404,8 +1404,8 @@ matchConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides);
@@ -1439,8 +1439,8 @@ matchConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides);
@@ -1476,8 +1476,8 @@ matchConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides);
@@ -1512,8 +1512,8 @@ matchConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmQOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides);
@@ -1550,8 +1550,8 @@ matchConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
                        &result.strides);
@@ -1588,8 +1588,8 @@ matchConvolutionOpOfType<linalg::DepthwiseConv3DNcdhwCdhwOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
                        &result.strides);
@@ -1626,8 +1626,8 @@ matchConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
                        &result.strides);
@@ -1664,8 +1664,8 @@ matchConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides, PoolingType::MaxSigned);
@@ -1698,8 +1698,8 @@ matchConvolutionOpOfType<linalg::PoolingNhwcMinOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides, PoolingType::MinSigned);
@@ -1732,8 +1732,8 @@ matchConvolutionOpOfType<linalg::PoolingNhwcSumOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides, PoolingType::Sum);
@@ -1767,8 +1767,8 @@ matchConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides, PoolingType::MaxUnsigned);
@@ -1802,8 +1802,8 @@ matchConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides, PoolingType::MinUnsigned);
@@ -1836,8 +1836,8 @@ matchConvolutionOpOfType<linalg::PoolingNchwSumOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides, PoolingType::Sum);
@@ -1870,8 +1870,8 @@ matchConvolutionOpOfType<linalg::PoolingNchwMaxOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
                        &result.strides, PoolingType::MaxSigned);
@@ -1904,8 +1904,8 @@ matchConvolutionOpOfType<linalg::PoolingNwcSumOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
                        &result.strides, PoolingType::Sum);
@@ -1934,8 +1934,8 @@ matchConvolutionOpOfType<linalg::PoolingNcwSumOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
                        &result.strides, PoolingType::Sum);
@@ -1964,8 +1964,8 @@ matchConvolutionOpOfType<linalg::PoolingNwcMaxOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
                        &result.strides, PoolingType::MaxSigned);
@@ -1995,8 +1995,8 @@ matchConvolutionOpOfType<linalg::PoolingNwcMaxUnsignedOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
                        &result.strides, PoolingType::MaxUnsigned);
@@ -2025,8 +2025,8 @@ matchConvolutionOpOfType<linalg::PoolingNcwMaxOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
                        &result.strides, PoolingType::MaxSigned);
@@ -2055,8 +2055,8 @@ matchConvolutionOpOfType<linalg::PoolingNwcMinOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
                        &result.strides, PoolingType::MinSigned);
@@ -2086,8 +2086,8 @@ matchConvolutionOpOfType<linalg::PoolingNwcMinUnsignedOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
                        &result.strides, PoolingType::MinUnsigned);
@@ -2116,8 +2116,8 @@ matchConvolutionOpOfType<linalg::PoolingNdhwcSumOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
                        &result.strides, PoolingType::Sum);
@@ -2153,8 +2153,8 @@ matchConvolutionOpOfType<linalg::PoolingNdhwcMaxOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
                        &result.strides, PoolingType::MaxSigned);
@@ -2190,8 +2190,8 @@ matchConvolutionOpOfType<linalg::PoolingNdhwcMinOp>(LinalgOp op) {
     return result;
   }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return std::nullopt;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
                        &result.strides, PoolingType::MinSigned);
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 60a4c555fa19a..7798cb76e4fb9 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -3,113 +3,168 @@
 // CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
+#map_nhwc_hwcf_input = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+#map_nhwc_hwcf_filter = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+#map_nhwc_hwcf_output = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+
 // CHECK-LABEL: @conv_2d_nhwc_hwcf
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
 func.func @conv_2d_nhwc_hwcf(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?x?x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_nwc_wcf
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Both named and generic ops should decompose to conv_1d_nwc_wcf
+  // CHECK-COUNT-2: linalg.conv_1d_nwc_wcf
   %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
                                  strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>)
     outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<?x1x?x?xf32>
+  // Generic op version with same semantics.
+  %1 = linalg.generic {indexing_maps = [#map_nhwc_hwcf_input, #map_nhwc_hwcf_filter, #map_nhwc_hwcf_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.mulf %in, %in_0 : f32
+    %3 = arith.addf %out, %2 : f32
+    linalg.yield %3 : f32
+  } -> tensor<?x1x?x?xf32>
+  return %1 : tensor<?x1x?x?xf32>
 }
 
+#map_nchw_fchw_input = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
+#map_nchw_fchw_filter = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+#map_nchw_fchw_output = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+
 // CHECK-LABEL: @conv_2d_nchw_fchw
 // CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
 // CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
 func.func @conv_2d_nchw_fchw(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x?x1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_ncw_fcw
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Both named and generic ops should decompose to conv_1d_ncw_fcw
+  // CHECK-COUNT-2: linalg.conv_1d_ncw_fcw
   %0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>,
                                  strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>)
     outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<?x?x1x?xf32>
+  // Generic op version with same semantics.
+  %1 = linalg.generic {indexing_maps = [#map_nchw_fchw_input, #map_nchw_fchw_filter, #map_nchw_fchw_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>) outs(%0 : tensor<?x?x1x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.mulf %in, %in_0 : f32
+    %3 = arith.addf %out, %2 : f32
+    linalg.yield %3 : f32
+  } -> tensor<?x?x1x?xf32>
+  return %1 : tensor<?x?x1x?xf32>
 }
 
+#map_depthwise_nhwc_hwc_input = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 2 + d5, d3)>
+#map_depthwise_nhwc_hwc_filter = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
+#map_depthwise_nhwc_hwc_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
 // CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
 // CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>
 func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>) -> tensor<1x1x56x96xf32> {
   // CHECK: %[[RES:.+]] = tensor.empty
   %init = tensor.empty() : tensor<1x1x56x96xf32>
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICERES:.+]] = tensor.extract_slice %[[RES]]
-  // CHECK: %[[OPRES:.+]] = linalg.depthwise_conv_1d_nwc_wc
-  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
-  // CHECK-SAME: outs(%[[SLICERES]]
-  // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[OPRES]] into %[[RES]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // Both named and generic ops should decompose to depthwise_conv_1d_nwc_wc
+  // CHECK-COUNT-2: linalg.depthwise_conv_1d_nwc_wc
   %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
          ins(%input, %filter: tensor<1x1x113x96xf32>, tensor<1x3x96xf32>)
          outs(%init: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32>
-  // CHECK: %[[INSERTED]]
-  return %0: tensor<1x1x56x96xf32>
+  // Generic op version with same semantics (strides = 2).
+  %1 = linalg.generic {indexing_maps = [#map_depthwise_nhwc_hwc_input, #map_depthwise_nhwc_hwc_filter, #map_depthwise_nhwc_hwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<1x1x113x96xf32>, tensor<1x3x96xf32>) outs(%0 : tensor<1x1x56x96xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.mulf %in, %in_0 : f32
+    %3 = arith.addf %out, %2 : f32
+    linalg.yield %3 : f32
+  } -> tensor<1x1x56x96xf32>
+  return %1: tensor<1x1x56x96xf32>
 }
 
+#map_conv_2d_input = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
+#map_conv_2d_filter = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map_conv_2d_output = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+
 // CHECK-LABEL: @conv_2d
 // CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<1x?xf32>,
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
 // CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<1x?xf32>)
 func.func @conv_2d(%input: tensor<1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<1x?xf32>) -> tensor<1x?xf32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Both named and generic ops should decompose to conv_1d
+  // CHECK-COUNT-2: linalg.conv_1d
   %0 = linalg.conv_2d
      ins (%input, %filter: tensor<1x?xf32>, tensor<1x?xf32>)
     outs (%init: tensor<1x?xf32>) -> tensor<1x?xf32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<1x?xf32>
+  // Generic op version with same semantics.
+  %1 = linalg.generic {indexing_maps = [#map_conv_2d_input, #map_conv_2d_filter, #map_conv_2d_output], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<1x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<1x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.mulf %in, %in_0 : f32
+    %3 = arith.addf %out, %2 : f32
+    linalg.yield %3 : f32
+  } -> tensor<1x?xf32>
+  return %1 : tensor<1x?xf32>
 }
 
+#map_pooling_nhwc_input = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+#map_pooling_nhwc_filter = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+#map_pooling_nhwc_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
 // CHECK-LABEL: @pooling_nhwc_sum
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
 func.func @pooling_nhwc_sum(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_sum
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Both named and generic ops should decompose to pooling_nwc_sum
+  // CHECK-COUNT-2: linalg.pooling_nwc_sum
   %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
     outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<?x1x?x?xf32>
+  // Generic op version with same semantics.
+  %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.addf %out, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x1x?x?xf32>
+  return %1 : tensor<?x1x?x?xf32>
 }
 
+#map_pooling_nchw_input = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
+#map_pooling_nchw_filter = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+#map_pooling_nchw_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
 // CHECK-LABEL: @pooling_nchw_sum
 // CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
 // CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
 func.func @pooling_nchw_sum(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_sum
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Both named and generic ops should decompose to pooling_ncw_sum
+  // CHECK-COUNT-2: linalg.pooling_ncw_sum
   %0 = linalg.pooling_nchw_sum {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
     outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<?x?x1x?xf32>
+  // Generic op version with same semantics.
+  %1 = linalg.generic {indexing_maps = [#map_pooling_nchw_input, #map_pooling_nchw_filter, #map_pooling_nchw_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x?x1x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.addf %out, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x1x?xf32>
+  return %1 : tensor<?x?x1x?xf32>
 }
 
 // CHECK-LABEL: @pooling_nhwc_max
@@ -117,17 +172,22 @@ func.func @pooling_nchw_sum(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
 func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Both named and generic ops should decompose to pooling_nwc_max
+  // CHECK-COUNT-2: linalg.pooling_nwc_max
   %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
     outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<?x1x?x?xf32>
+  // Generic op version with same semantics.
+  %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.maximumf %out, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x1x?x?xf32>
+  return %1 : tensor<?x1x?x?xf32>
 }
 
 // CHECK-LABEL: @pooling_nhwc_max_unsigned
@@ -135,17 +195,22 @@ func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
 func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max_unsigned
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Both named and generic ops should decompose to pooling_nwc_max_unsigned
+  // CHECK-COUNT-2: linalg.pooling_nwc_max_unsigned
   %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
     outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<?x1x?x?xi32>
+  // Generic op version with same semantics.
+  %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xi32>, tensor<1x?xi32>) outs(%0 : tensor<?x1x?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %2 = arith.maxui %out, %in : i32
+    linalg.yield %2 : i32
+  } -> tensor<?x1x?x?xi32>
+  return %1 : tensor<?x1x?x?xi32>
 }
 
 // CHECK-LABEL: @pooling_nhwc_min
@@ -153,17 +218,22 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tenso
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
 func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Both named and generic ops should decompose to pooling_nwc_min
+  // CHECK-COUNT-2: linalg.pooling_nwc_min
   %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
     outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<?x1x?x?xf32>
+  // Generic op version with same semantics.
+  %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.minimumf %out, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x1x?x?xf32>
+  return %1 : tensor<?x1x?x?xf32>
 }
 
 // CHECK-LABEL: @pooling_nhwc_min_unsigned
@@ -171,17 +241,22 @@ func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
 func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min_unsigned
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Both named and generic ops should decompose to pooling_nwc_min_unsigned
+  // CHECK-COUNT-2: linalg.pooling_nwc_min_unsigned
   %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
     outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<?x1x?x?xi32>
+  // Generic op version with same semantics.
+  %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xi32>, tensor<1x?xi32>) outs(%0 : tensor<?x1x?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %2 = arith.minui %out, %in : i32
+    linalg.yield %2 : i32
+  } -> tensor<?x1x?x?xi32>
+  return %1 : tensor<?x1x?x?xi32>
 }
 
 // CHECK-LABEL: @pooling_nchw_max
@@ -189,17 +264,22 @@ func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tenso
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
 // CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
 func.func @pooling_nchw_max(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_max
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Both named and generic ops should decompose to pooling_ncw_max
+  // CHECK-COUNT-2: linalg.pooling_ncw_max
   %0 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
     outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<?x?x1x?xf32>
+  // Generic op version with same semantics.
+  %1 = linalg.generic {indexing_maps = [#map_pooling_nchw_input, #map_pooling_nchw_filter, #map_pooling_nchw_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x?x1x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.maximumf %out, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x1x?xf32>
+  return %1 : tensor<?x?x1x?xf32>
 }
 
 func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {

>From 6302ec5c03849625c2d12e63fd43058440800929 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 6 Jan 2026 08:10:41 +0000
Subject: [PATCH 2/2] Review comment by Hanhan v1.0 : Update API + different
 RUN line

---
 .../Dialect/Linalg/Transforms/Transforms.cpp  |  20 +-
 .../Linalg/transform-op-decompose.mlir        | 239 ++++++------------
 2 files changed, 93 insertions(+), 166 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 7972408318b95..fc7cdad0ee33d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1409,9 +1409,12 @@ template <typename Conv2DOp, typename Conv1DOp>
 FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
     returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const {
   // Check if this LinalgOp is of the expected Conv2DOp type (named or generic).
-  SmallVector<int64_t> dilations, strides;
-  if (!linalg::isaConvolutionOpOfType<Conv2DOp>(convOp, &dilations, &strides))
+  std::optional<DilationsAndStrides> convParams =
+      matchConvolutionOpOfType<Conv2DOp>(convOp);
+  if (!convParams)
     return failure();
+  SmallVector<int64_t> dilations = std::move(convParams->dilations);
+  SmallVector<int64_t> strides = std::move(convParams->strides);
 
   if (convOp.hasPureBufferSemantics())
     return failure(); // To be implemented.
@@ -1526,10 +1529,12 @@ FailureOr<DepthwiseConv1DNwcWcOp>
 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
     LinalgOp convOp, PatternRewriter &rewriter) const {
   // Check if this LinalgOp is a DepthwiseConv2DNhwcHwcOp (named or generic).
-  SmallVector<int64_t> dilations, strides;
-  if (!linalg::isaConvolutionOpOfType<DepthwiseConv2DNhwcHwcOp>(
-          convOp, &dilations, &strides))
+  std::optional<DilationsAndStrides> convParams =
+      matchConvolutionOpOfType<DepthwiseConv2DNhwcHwcOp>(convOp);
+  if (!convParams)
     return failure();
+  SmallVector<int64_t> dilations = std::move(convParams->dilations);
+  SmallVector<int64_t> strides = std::move(convParams->strides);
 
   if (convOp.hasPureBufferSemantics())
     return failure(); // To be implemented.
@@ -1597,8 +1602,9 @@ FailureOr<Conv1DOp>
 DownscaleConv2DOp::returningMatchAndRewrite(LinalgOp convOp,
                                             PatternRewriter &rewriter) const {
   // Check if this LinalgOp is a Conv2DOp (named or generic).
-  SmallVector<int64_t> dilations, strides;
-  if (!linalg::isaConvolutionOpOfType<Conv2DOp>(convOp, &dilations, &strides))
+  std::optional<DilationsAndStrides> convParams =
+      matchConvolutionOpOfType<Conv2DOp>(convOp);
+  if (!convParams)
     return failure();
 
   if (convOp.hasPureBufferSemantics())
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 7798cb76e4fb9..9c9aaf8c20b8d 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -1,170 +1,116 @@
 // RUN: mlir-opt --transform-interpreter --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --linalg-generalize-named-ops --transform-interpreter --split-input-file %s | FileCheck %s
 
 // CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
-#map_nhwc_hwcf_input = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
-#map_nhwc_hwcf_filter = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
-#map_nhwc_hwcf_output = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-
 // CHECK-LABEL: @conv_2d_nhwc_hwcf
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
 func.func @conv_2d_nhwc_hwcf(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?x?x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
-  // CHECK: tensor.extract_slice %[[ARG0]]
-  // CHECK: tensor.extract_slice %[[ARG1]]
-  // CHECK: tensor.extract_slice %[[ARG2]]
-  // Both named and generic ops should decompose to conv_1d_nwc_wcf
-  // CHECK-COUNT-2: linalg.conv_1d_nwc_wcf
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_nwc_wcf
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
   %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
                                  strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>)
     outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
-  // Generic op version with same semantics.
-  %1 = linalg.generic {indexing_maps = [#map_nhwc_hwcf_input, #map_nhwc_hwcf_filter, #map_nhwc_hwcf_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
-  ^bb0(%in: f32, %in_0: f32, %out: f32):
-    %2 = arith.mulf %in, %in_0 : f32
-    %3 = arith.addf %out, %2 : f32
-    linalg.yield %3 : f32
-  } -> tensor<?x1x?x?xf32>
-  return %1 : tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
 }
 
-#map_nchw_fchw_input = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
-#map_nchw_fchw_filter = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
-#map_nchw_fchw_output = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-
 // CHECK-LABEL: @conv_2d_nchw_fchw
 // CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
 // CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
 func.func @conv_2d_nchw_fchw(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x?x1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
-  // CHECK: tensor.extract_slice %[[ARG0]]
-  // CHECK: tensor.extract_slice %[[ARG1]]
-  // CHECK: tensor.extract_slice %[[ARG2]]
-  // Both named and generic ops should decompose to conv_1d_ncw_fcw
-  // CHECK-COUNT-2: linalg.conv_1d_ncw_fcw
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_ncw_fcw
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
   %0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>,
                                  strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>)
     outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
-  // Generic op version with same semantics.
-  %1 = linalg.generic {indexing_maps = [#map_nchw_fchw_input, #map_nchw_fchw_filter, #map_nchw_fchw_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>) outs(%0 : tensor<?x?x1x?xf32>) {
-  ^bb0(%in: f32, %in_0: f32, %out: f32):
-    %2 = arith.mulf %in, %in_0 : f32
-    %3 = arith.addf %out, %2 : f32
-    linalg.yield %3 : f32
-  } -> tensor<?x?x1x?xf32>
-  return %1 : tensor<?x?x1x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x?x1x?xf32>
 }
 
-#map_depthwise_nhwc_hwc_input = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 2 + d5, d3)>
-#map_depthwise_nhwc_hwc_filter = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
-#map_depthwise_nhwc_hwc_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-
 // CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
 // CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>
 func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>) -> tensor<1x1x56x96xf32> {
   // CHECK: %[[RES:.+]] = tensor.empty
   %init = tensor.empty() : tensor<1x1x56x96xf32>
-  // CHECK: tensor.extract_slice %[[ARG0]]
-  // CHECK: tensor.extract_slice %[[ARG1]]
-  // Both named and generic ops should decompose to depthwise_conv_1d_nwc_wc
-  // CHECK-COUNT-2: linalg.depthwise_conv_1d_nwc_wc
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICERES:.+]] = tensor.extract_slice %[[RES]]
+  // CHECK: %[[OPRES:.+]] = linalg.depthwise_conv_1d_nwc_wc
+  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
+  // CHECK-SAME: outs(%[[SLICERES]]
+  // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[OPRES]] into %[[RES]]
   %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
          ins(%input, %filter: tensor<1x1x113x96xf32>, tensor<1x3x96xf32>)
          outs(%init: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32>
-  // Generic op version with same semantics (strides = 2).
-  %1 = linalg.generic {indexing_maps = [#map_depthwise_nhwc_hwc_input, #map_depthwise_nhwc_hwc_filter, #map_depthwise_nhwc_hwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<1x1x113x96xf32>, tensor<1x3x96xf32>) outs(%0 : tensor<1x1x56x96xf32>) {
-  ^bb0(%in: f32, %in_0: f32, %out: f32):
-    %2 = arith.mulf %in, %in_0 : f32
-    %3 = arith.addf %out, %2 : f32
-    linalg.yield %3 : f32
-  } -> tensor<1x1x56x96xf32>
-  return %1: tensor<1x1x56x96xf32>
+  // CHECK: %[[INSERTED]]
+  return %0: tensor<1x1x56x96xf32>
 }
 
-#map_conv_2d_input = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
-#map_conv_2d_filter = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
-#map_conv_2d_output = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
-
 // CHECK-LABEL: @conv_2d
 // CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<1x?xf32>,
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
 // CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<1x?xf32>)
 func.func @conv_2d(%input: tensor<1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<1x?xf32>) -> tensor<1x?xf32> {
-  // CHECK: tensor.extract_slice %[[ARG0]]
-  // CHECK: tensor.extract_slice %[[ARG1]]
-  // CHECK: tensor.extract_slice %[[ARG2]]
-  // Both named and generic ops should decompose to conv_1d
-  // CHECK-COUNT-2: linalg.conv_1d
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
   %0 = linalg.conv_2d
      ins (%input, %filter: tensor<1x?xf32>, tensor<1x?xf32>)
     outs (%init: tensor<1x?xf32>) -> tensor<1x?xf32>
-  // Generic op version with same semantics.
-  %1 = linalg.generic {indexing_maps = [#map_conv_2d_input, #map_conv_2d_filter, #map_conv_2d_output], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<1x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<1x?xf32>) {
-  ^bb0(%in: f32, %in_0: f32, %out: f32):
-    %2 = arith.mulf %in, %in_0 : f32
-    %3 = arith.addf %out, %2 : f32
-    linalg.yield %3 : f32
-  } -> tensor<1x?xf32>
-  return %1 : tensor<1x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<1x?xf32>
 }
 
-#map_pooling_nhwc_input = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
-#map_pooling_nhwc_filter = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
-#map_pooling_nhwc_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-
 // CHECK-LABEL: @pooling_nhwc_sum
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
 func.func @pooling_nhwc_sum(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
-  // CHECK: tensor.extract_slice %[[ARG0]]
-  // CHECK: tensor.extract_slice %[[ARG1]]
-  // CHECK: tensor.extract_slice %[[ARG2]]
-  // Both named and generic ops should decompose to pooling_nwc_sum
-  // CHECK-COUNT-2: linalg.pooling_nwc_sum
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_sum
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
   %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
     outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
-  // Generic op version with same semantics.
-  %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
-  ^bb0(%in: f32, %in_0: f32, %out: f32):
-    %2 = arith.addf %out, %in : f32
-    linalg.yield %2 : f32
-  } -> tensor<?x1x?x?xf32>
-  return %1 : tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
 }
 
-#map_pooling_nchw_input = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
-#map_pooling_nchw_filter = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
-#map_pooling_nchw_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-
 // CHECK-LABEL: @pooling_nchw_sum
 // CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
 // CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
 func.func @pooling_nchw_sum(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
-  // CHECK: tensor.extract_slice %[[ARG0]]
-  // CHECK: tensor.extract_slice %[[ARG1]]
-  // CHECK: tensor.extract_slice %[[ARG2]]
-  // Both named and generic ops should decompose to pooling_ncw_sum
-  // CHECK-COUNT-2: linalg.pooling_ncw_sum
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_sum
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
   %0 = linalg.pooling_nchw_sum {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
     outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
-  // Generic op version with same semantics.
-  %1 = linalg.generic {indexing_maps = [#map_pooling_nchw_input, #map_pooling_nchw_filter, #map_pooling_nchw_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x?x1x?xf32>) {
-  ^bb0(%in: f32, %in_0: f32, %out: f32):
-    %2 = arith.addf %out, %in : f32
-    linalg.yield %2 : f32
-  } -> tensor<?x?x1x?xf32>
-  return %1 : tensor<?x?x1x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x?x1x?xf32>
 }
 
 // CHECK-LABEL: @pooling_nhwc_max
@@ -172,22 +118,17 @@ func.func @pooling_nchw_sum(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
 func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
-  // CHECK: tensor.extract_slice %[[ARG0]]
-  // CHECK: tensor.extract_slice %[[ARG1]]
-  // CHECK: tensor.extract_slice %[[ARG2]]
-  // Both named and generic ops should decompose to pooling_nwc_max
-  // CHECK-COUNT-2: linalg.pooling_nwc_max
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
   %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
     outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
-  // Generic op version with same semantics.
-  %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
-  ^bb0(%in: f32, %in_0: f32, %out: f32):
-    %2 = arith.maximumf %out, %in : f32
-    linalg.yield %2 : f32
-  } -> tensor<?x1x?x?xf32>
-  return %1 : tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
 }
 
 // CHECK-LABEL: @pooling_nhwc_max_unsigned
@@ -195,22 +136,17 @@ func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
 func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
-  // CHECK: tensor.extract_slice %[[ARG0]]
-  // CHECK: tensor.extract_slice %[[ARG1]]
-  // CHECK: tensor.extract_slice %[[ARG2]]
-  // Both named and generic ops should decompose to pooling_nwc_max_unsigned
-  // CHECK-COUNT-2: linalg.pooling_nwc_max_unsigned
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max_unsigned
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
   %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
     outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
-  // Generic op version with same semantics.
-  %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xi32>, tensor<1x?xi32>) outs(%0 : tensor<?x1x?x?xi32>) {
-  ^bb0(%in: i32, %in_0: i32, %out: i32):
-    %2 = arith.maxui %out, %in : i32
-    linalg.yield %2 : i32
-  } -> tensor<?x1x?x?xi32>
-  return %1 : tensor<?x1x?x?xi32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xi32>
 }
 
 // CHECK-LABEL: @pooling_nhwc_min
@@ -218,22 +154,17 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tenso
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
 func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
-  // CHECK: tensor.extract_slice %[[ARG0]]
-  // CHECK: tensor.extract_slice %[[ARG1]]
-  // CHECK: tensor.extract_slice %[[ARG2]]
-  // Both named and generic ops should decompose to pooling_nwc_min
-  // CHECK-COUNT-2: linalg.pooling_nwc_min
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
   %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
     outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
-  // Generic op version with same semantics.
-  %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
-  ^bb0(%in: f32, %in_0: f32, %out: f32):
-    %2 = arith.minimumf %out, %in : f32
-    linalg.yield %2 : f32
-  } -> tensor<?x1x?x?xf32>
-  return %1 : tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
 }
 
 // CHECK-LABEL: @pooling_nhwc_min_unsigned
@@ -241,22 +172,17 @@ func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
 func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
-  // CHECK: tensor.extract_slice %[[ARG0]]
-  // CHECK: tensor.extract_slice %[[ARG1]]
-  // CHECK: tensor.extract_slice %[[ARG2]]
-  // Both named and generic ops should decompose to pooling_nwc_min_unsigned
-  // CHECK-COUNT-2: linalg.pooling_nwc_min_unsigned
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min_unsigned
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
   %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
     outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
-  // Generic op version with same semantics.
-  %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xi32>, tensor<1x?xi32>) outs(%0 : tensor<?x1x?x?xi32>) {
-  ^bb0(%in: i32, %in_0: i32, %out: i32):
-    %2 = arith.minui %out, %in : i32
-    linalg.yield %2 : i32
-  } -> tensor<?x1x?x?xi32>
-  return %1 : tensor<?x1x?x?xi32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xi32>
 }
 
 // CHECK-LABEL: @pooling_nchw_max
@@ -264,22 +190,17 @@ func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tenso
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
 // CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
 func.func @pooling_nchw_max(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
-  // CHECK: tensor.extract_slice %[[ARG0]]
-  // CHECK: tensor.extract_slice %[[ARG1]]
-  // CHECK: tensor.extract_slice %[[ARG2]]
-  // Both named and generic ops should decompose to pooling_ncw_max
-  // CHECK-COUNT-2: linalg.pooling_ncw_max
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_max
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
   %0 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
     outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
-  // Generic op version with same semantics.
-  %1 = linalg.generic {indexing_maps = [#map_pooling_nchw_input, #map_pooling_nchw_filter, #map_pooling_nchw_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x?x1x?xf32>) {
-  ^bb0(%in: f32, %in_0: f32, %out: f32):
-    %2 = arith.maximumf %out, %in : f32
-    linalg.yield %2 : f32
-  } -> tensor<?x?x1x?xf32>
-  return %1 : tensor<?x?x1x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x?x1x?xf32>
 }
 
 func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {



More information about the Mlir-commits mailing list