[Mlir-commits] [mlir] [Linalg] Update transform + vectorization patterns to work with generic convolution ops as well (PR #174196)

Abhishek Varma llvmlistbot at llvm.org
Fri Jan 2 02:45:33 PST 2026


https://github.com/Abhishek-Varma created https://github.com/llvm/llvm-project/pull/174196

-- This commit updates the Linalg transform and vectorization patterns
    to work with both named and generic convolution ops.
-- This required the following updates to the `isaConvolutionOpOfType` API:
    1. Allow dilations/strides to be optional arguments.
    2. Populate dilations/strides info for named convolution ops as well.
    3. Since a "generic" LinalgOp is now used as the root op in the patterns
       above, the `assert` requiring the op to implement
       ConvolutionOpInterface has been replaced with an early-exit `if`.

Signed-off-by: Abhishek Varma <abhvarma at amd.com>

>From 0da5223a29c76b3c9c64040c602a93a804fe9846 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 1 Jan 2026 07:36:31 +0000
Subject: [PATCH 1/2] [Linalg] Update DecomposeConvolutionToLowerDimOpsPass to
 work with generic conv ops

-- This commit updates DecomposeConvolutionToLowerDimOpsPass to
   work with both named and generic convolution ops.
-- This required an update to the `isaConvolutionOpOfType` API so that it
   also populates dilations/strides info for named convolution ops. Since a
   generic LinalgOp is now used as the root op in the pattern above, the
   assert requiring the op to implement ConvolutionOpInterface has been
   replaced with an early-exit `if`.

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
 .../Dialect/Linalg/Transforms/Transforms.h    |  27 +-
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 107 +++--
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 436 ++++++++++++------
 .../Linalg/transform-op-decompose.mlir        | 238 ++++++----
 4 files changed, 515 insertions(+), 293 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6678d693719bf..32067358438d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1641,16 +1641,16 @@ FailureOr<linalg::GenericOp> deduplicateOperandsAndRemoveDeadResults(
 //===----------------------------------------------------------------------===//
 
 /// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
-/// convolution ops.
+/// convolution ops. Works with both named ops and equivalent generic ops.
 template <typename Conv2DOp, typename Conv1DOp>
 struct DownscaleSizeOneWindowed2DConvolution final
-    : public OpRewritePattern<Conv2DOp> {
-  using OpRewritePattern<Conv2DOp>::OpRewritePattern;
+    : public OpInterfaceRewritePattern<LinalgOp> {
+  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
 
-  FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
+  FailureOr<Conv1DOp> returningMatchAndRewrite(LinalgOp convOp,
                                                PatternRewriter &rewriter) const;
 
-  LogicalResult matchAndRewrite(Conv2DOp convOp,
+  LogicalResult matchAndRewrite(LinalgOp convOp,
                                 PatternRewriter &rewriter) const override {
     return returningMatchAndRewrite(convOp, rewriter);
   }
@@ -1664,29 +1664,28 @@ extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
 /// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
 /// dimensions into 1-D depthwise convolution ops.
 struct DownscaleDepthwiseConv2DNhwcHwcOp final
-    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
+    : public OpInterfaceRewritePattern<LinalgOp> {
   DownscaleDepthwiseConv2DNhwcHwcOp(MLIRContext *context,
                                     PatternBenefit benefit = 1)
-      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit) {}
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}
 
   FailureOr<DepthwiseConv1DNwcWcOp>
-  returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
-                           PatternRewriter &rewriter) const;
+  returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const;
 
-  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
+  LogicalResult matchAndRewrite(LinalgOp convOp,
                                 PatternRewriter &rewriter) const override {
     return returningMatchAndRewrite(convOp, rewriter);
   }
 };
 
-struct DownscaleConv2DOp final : public OpRewritePattern<Conv2DOp> {
+struct DownscaleConv2DOp final : public OpInterfaceRewritePattern<LinalgOp> {
   DownscaleConv2DOp(MLIRContext *context, PatternBenefit benefit = 1)
-      : OpRewritePattern<Conv2DOp>(context, benefit) {}
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}
 
-  FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
+  FailureOr<Conv1DOp> returningMatchAndRewrite(LinalgOp convOp,
                                                PatternRewriter &rewriter) const;
 
-  LogicalResult matchAndRewrite(Conv2DOp convOp,
+  LogicalResult matchAndRewrite(LinalgOp convOp,
                                 PatternRewriter &rewriter) const override {
     return returningMatchAndRewrite(convOp, rewriter);
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 96cc378f6c21a..7972408318b95 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -32,6 +32,7 @@
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/InterleavedRange.h"
 #include "llvm/Support/raw_ostream.h"
+#include <type_traits>
 #include <utility>
 
 #define DEBUG_TYPE "linalg-transforms"
@@ -1406,13 +1407,18 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
 
 template <typename Conv2DOp, typename Conv1DOp>
 FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
-    returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
+    returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const {
+  // Check if this LinalgOp is of the expected Conv2DOp type (named or generic).
+  SmallVector<int64_t> dilations, strides;
+  if (!linalg::isaConvolutionOpOfType<Conv2DOp>(convOp, &dilations, &strides))
+    return failure();
+
   if (convOp.hasPureBufferSemantics())
     return failure(); // To be implemented.
 
-  Value input = convOp.getInputs().front();
-  Value kernel = convOp.getInputs().back();
-  Value output = convOp.getOutputs().front();
+  Value input = convOp.getDpsInputs().front();
+  Value kernel = convOp.getDpsInputs().back();
+  Value output = convOp.getDpsInits().front();
 
   auto inputType = dyn_cast<RankedTensorType>(input.getType());
   auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
@@ -1421,38 +1427,33 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
   auto kernelShape = kernelType.getShape();
   auto outputShape = outputType.getShape();
 
-  // Get domain indices based on conv2D layout.
-  auto [khIndex, kwIndex, ohIndex, owIndex] =
-      TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
-          convOp)
-          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
-            return std::make_tuple(0, 1, 1, 2);
-          })
-          .Case([&](linalg::Conv2DNchwFchwOp op) {
-            return std::make_tuple(2, 3, 2, 3);
-          })
-          .Case([&](linalg::PoolingNhwcSumOp op) {
-            return std::make_tuple(0, 1, 1, 2);
-          })
-          .Case([&](linalg::PoolingNchwSumOp op) {
-            return std::make_tuple(0, 1, 2, 3);
-          })
-          .Case([&](linalg::PoolingNhwcMaxOp op) {
-            return std::make_tuple(0, 1, 1, 2);
-          })
-          .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
-            return std::make_tuple(0, 1, 1, 2);
-          })
-          .Case([&](linalg::PoolingNhwcMinOp op) {
-            return std::make_tuple(0, 1, 1, 2);
-          })
-          .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
-            return std::make_tuple(0, 1, 1, 2);
-          })
-          .Case([&](linalg::PoolingNchwMaxOp op) {
-            return std::make_tuple(0, 1, 2, 3);
-          })
-          .DefaultUnreachable("unexpected conv2d/pool2d operation.");
+  // Get domain indices based on Conv2DOp type. These are known at compile time.
+  int64_t khIndex, kwIndex, ohIndex, owIndex;
+  if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNhwcHwcfOp> ||
+                std::is_same_v<Conv2DOp, linalg::PoolingNhwcSumOp> ||
+                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxOp> ||
+                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxUnsignedOp> ||
+                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinOp> ||
+                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinUnsignedOp>) {
+    // NHWC layout: kernel [H, W, ...], output [N, H, W, C]
+    khIndex = 0;
+    kwIndex = 1;
+    ohIndex = 1;
+    owIndex = 2;
+  } else if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNchwFchwOp>) {
+    // NCHW_FCHW layout: kernel [..., H, W], output [N, C, H, W]
+    khIndex = 2;
+    kwIndex = 3;
+    ohIndex = 2;
+    owIndex = 3;
+  } else if constexpr (std::is_same_v<Conv2DOp, linalg::PoolingNchwSumOp> ||
+                       std::is_same_v<Conv2DOp, linalg::PoolingNchwMaxOp>) {
+    // NCHW pooling layout: kernel [H, W], output [N, C, H, W]
+    khIndex = 0;
+    kwIndex = 1;
+    ohIndex = 2;
+    owIndex = 3;
+  }
 
   // Only handle the case where at least one of the window dimensions is
   // of size 1. Other cases can rely on tiling to reduce to such cases.
@@ -1484,13 +1485,9 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
 
   // Rank-reduce strides and dilations too.
   // TODO: dropDim 1-liner helper.
-  auto strides =
-      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
   strides.erase(strides.begin() + (removeH ? 0 : 1));
   auto stridesAttr = rewriter.getI64VectorAttr(strides);
 
-  auto dilations =
-      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
   dilations.erase(dilations.begin() + (removeH ? 0 : 1));
   auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
 
@@ -1527,13 +1524,19 @@ template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
 
 FailureOr<DepthwiseConv1DNwcWcOp>
 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
-    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
+    LinalgOp convOp, PatternRewriter &rewriter) const {
+  // Check if this LinalgOp is a DepthwiseConv2DNhwcHwcOp (named or generic).
+  SmallVector<int64_t> dilations, strides;
+  if (!linalg::isaConvolutionOpOfType<DepthwiseConv2DNhwcHwcOp>(
+          convOp, &dilations, &strides))
+    return failure();
+
   if (convOp.hasPureBufferSemantics())
     return failure(); // To be implemented.
 
-  Value input = convOp.getInputs().front();
-  Value kernel = convOp.getInputs().back();
-  Value output = convOp.getOutputs().front();
+  Value input = convOp.getDpsInputs().front();
+  Value kernel = convOp.getDpsInputs().back();
+  Value output = convOp.getDpsInits().front();
 
   auto inputType = dyn_cast<RankedTensorType>(input.getType());
   auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
@@ -1572,12 +1575,9 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
 
   // Rank-reduce strides and dilations too.
   // TODO: dropDim 1-liner helper.
-  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
   strides.erase(strides.begin() + (removeH ? 0 : 1));
   auto stridesAttr = rewriter.getI64VectorAttr(strides);
 
-  auto dilations =
-      llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
   dilations.erase(dilations.begin() + (removeH ? 0 : 1));
   auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
 
@@ -1594,14 +1594,19 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
 }
 
 FailureOr<Conv1DOp>
-DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
+DownscaleConv2DOp::returningMatchAndRewrite(LinalgOp convOp,
                                             PatternRewriter &rewriter) const {
+  // Check if this LinalgOp is a Conv2DOp (named or generic).
+  SmallVector<int64_t> dilations, strides;
+  if (!linalg::isaConvolutionOpOfType<Conv2DOp>(convOp, &dilations, &strides))
+    return failure();
+
   if (convOp.hasPureBufferSemantics())
     return failure(); // To be implemented.
 
-  Value input = convOp.getInputs().front();
-  Value kernel = convOp.getInputs().back();
-  Value output = convOp.getOutputs().front();
+  Value input = convOp.getDpsInputs().front();
+  Value kernel = convOp.getDpsInputs().back();
+  Value output = convOp.getDpsInits().front();
 
   auto inputType = dyn_cast<RankedTensorType>(input.getType());
   auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 2718124251c18..eccf648e53014 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -373,11 +373,8 @@ static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
                                                                   body);
 }
 
-// max_unsigned ops should not allow float data type.
-// TODO(#164800): Retire OPDSL logic.
 static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
-  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
-                                                                  body);
+  return bodyMatcherForPoolOps<arith::MaxUIOp>(yieldVal, body);
 }
 
 static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
@@ -385,11 +382,8 @@ static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
                                                                   body);
 }
 
-// min_unsigned ops should not allow float data type.
-// TODO(#164800): Retire OPDSL logic.
 static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
-  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
-                                                                  body);
+  return bodyMatcherForPoolOps<arith::MinUIOp>(yieldVal, body);
 }
 
 static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
@@ -601,11 +595,15 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
                                               SmallVector<int64_t> *dilations,
                                               SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv1DOp>(op))
+  if (isa<linalg::Conv1DOp>(op)) {
+    // Conv1DOp has no strides/dilations attributes, default to 1.
+    *dilations = SmallVector<int64_t>(1, 1);
+    *strides = SmallVector<int64_t>(1, 1);
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
   AffineExpr W = m.dim(0);
@@ -622,11 +620,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv1DNwcWcfOp>(op))
+  if (auto convOp = dyn_cast<linalg::Conv1DNwcWcfOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -646,11 +647,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv1DNcwFcwOp>(op))
+  if (auto convOp = dyn_cast<linalg::Conv1DNcwFcwOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -670,11 +674,15 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
                                               SmallVector<int64_t> *dilations,
                                               SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DOp>(op))
+  if (isa<linalg::Conv2DOp>(op)) {
+    // Conv2DOp has no strides/dilations attributes, default to 1.
+    *dilations = SmallVector<int64_t>(2, 1);
+    *strides = SmallVector<int64_t>(2, 1);
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
   AffineExpr H = m.dim(0);
@@ -694,11 +702,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNhwcHwcfOp>(op))
+  if (auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -721,11 +732,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNhwcHwcfQOp>(op))
+  if (auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfQOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -750,11 +764,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNhwcFhwcOp>(op))
+  if (auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -777,11 +794,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNhwcFhwcQOp>(op))
+  if (auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcQOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -806,11 +826,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNchwFchwOp>(op))
+  if (auto convOp = dyn_cast<linalg::Conv2DNchwFchwOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -833,11 +856,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNchwFchwQOp>(op))
+  if (auto convOp = dyn_cast<linalg::Conv2DNchwFchwQOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -862,11 +888,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNgchwFgchwOp>(op))
+  if (auto convOp = dyn_cast<linalg::Conv2DNgchwFgchwOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -891,11 +920,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNgchwGfchwOp>(op))
+  if (auto convOp = dyn_cast<linalg::Conv2DNgchwGfchwOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -920,11 +952,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNgchwGfchwQOp>(op))
+  if (auto convOp = dyn_cast<linalg::Conv2DNgchwGfchwQOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -951,11 +986,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNhwgcGfhwcOp>(op))
+  if (auto convOp = dyn_cast<linalg::Conv2DNhwgcGfhwcOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -980,11 +1018,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op))
+  if (auto convOp = dyn_cast<linalg::Conv2DNhwgcGfhwcQOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -1011,11 +1052,15 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
                                               SmallVector<int64_t> *dilations,
                                               SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv3DOp>(op))
+  if (isa<linalg::Conv3DOp>(op)) {
+    // Conv3DOp has no strides/dilations attributes, default to 1.
+    *dilations = SmallVector<int64_t>(3, 1);
+    *strides = SmallVector<int64_t>(3, 1);
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
   AffineExpr D = m.dim(0);
@@ -1039,11 +1084,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv3DNdhwcDhwcfOp>(op))
+  if (auto convOp = dyn_cast<linalg::Conv3DNdhwcDhwcfOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -1070,11 +1118,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv3DNdhwcDhwcfQOp>(op))
+  if (auto convOp = dyn_cast<linalg::Conv3DNdhwcDhwcfQOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -1103,11 +1154,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv3DNcdhwFcdhwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv3DNcdhwFcdhwOp>(op))
+  if (auto convOp = dyn_cast<linalg::Conv3DNcdhwFcdhwOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -1134,11 +1188,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
+  if (auto convOp = dyn_cast<linalg::DepthwiseConv1DNcwCwOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -1157,11 +1214,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+  if (auto convOp = dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -1180,11 +1240,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
+  if (auto convOp = dyn_cast<linalg::DepthwiseConv1DNwcWcmOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -1204,11 +1267,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
+  if (auto convOp = dyn_cast<linalg::DepthwiseConv2DNchwChwOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -1230,11 +1296,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(op))
+  if (auto convOp = dyn_cast<linalg::DepthwiseConv2DNhwcHwcOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -1256,11 +1325,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op))
+  if (auto convOp = dyn_cast<linalg::DepthwiseConv2DNhwcHwcQOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -1284,11 +1356,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op))
+  if (auto convOp = dyn_cast<linalg::DepthwiseConv2DNhwcHwcmOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -1311,11 +1386,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv2DNhwcHwcmQOp>(op))
+  if (auto convOp = dyn_cast<linalg::DepthwiseConv2DNhwcHwcmQOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -1340,11 +1418,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv3DNdhwcDhwcOp>(op))
+  if (auto convOp = dyn_cast<linalg::DepthwiseConv3DNdhwcDhwcOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -1370,11 +1451,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNcdhwCdhwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv3DNcdhwCdhwOp>(op))
+  if (auto convOp = dyn_cast<linalg::DepthwiseConv3DNcdhwCdhwOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -1400,11 +1484,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
+  if (auto convOp = dyn_cast<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
   AffineExpr N = m.dim(0);
@@ -1431,11 +1518,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::PoolingNhwcMaxOp>(op))
+  if (auto poolOp = dyn_cast<linalg::PoolingNhwcMaxOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                        PoolingType::MaxSigned);
@@ -1458,11 +1548,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::PoolingNhwcMinOp>(op))
+  if (auto poolOp = dyn_cast<linalg::PoolingNhwcMinOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                        PoolingType::MinSigned);
@@ -1485,11 +1578,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::PoolingNhwcSumOp>(op))
+  if (auto poolOp = dyn_cast<linalg::PoolingNhwcSumOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                        PoolingType::Sum);
@@ -1512,11 +1608,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
+  if (auto poolOp = dyn_cast<linalg::PoolingNhwcMaxUnsignedOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                        PoolingType::MaxUnsigned);
@@ -1539,11 +1638,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::PoolingNhwcMinUnsignedOp>(op))
+  if (auto poolOp = dyn_cast<linalg::PoolingNhwcMinUnsignedOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                        PoolingType::MinUnsigned);
@@ -1566,11 +1668,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNchwSumOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::PoolingNchwSumOp>(op))
+  if (auto poolOp = dyn_cast<linalg::PoolingNchwSumOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                        PoolingType::Sum);
@@ -1593,11 +1698,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNchwMaxOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::PoolingNchwMaxOp>(op))
+  if (auto poolOp = dyn_cast<linalg::PoolingNchwMaxOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                        PoolingType::MaxSigned);
@@ -1620,11 +1728,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNwcSumOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::PoolingNwcSumOp>(op))
+  if (auto poolOp = dyn_cast<linalg::PoolingNwcSumOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
                        PoolingType::Sum);
@@ -1644,11 +1755,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNcwSumOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::PoolingNcwSumOp>(op))
+  if (auto poolOp = dyn_cast<linalg::PoolingNcwSumOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
                        PoolingType::Sum);
@@ -1668,11 +1782,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNwcMaxOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::PoolingNwcMaxOp>(op))
+  if (auto poolOp = dyn_cast<linalg::PoolingNwcMaxOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
                        PoolingType::MaxSigned);
@@ -1692,11 +1809,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNwcMaxUnsignedOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::PoolingNwcMaxUnsignedOp>(op))
+  if (auto poolOp = dyn_cast<linalg::PoolingNwcMaxUnsignedOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
                        PoolingType::MaxUnsigned);
@@ -1716,11 +1836,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNcwMaxOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::PoolingNcwMaxOp>(op))
+  if (auto poolOp = dyn_cast<linalg::PoolingNcwMaxOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
                        PoolingType::MaxSigned);
@@ -1740,11 +1863,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNwcMinOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::PoolingNwcMinOp>(op))
+  if (auto poolOp = dyn_cast<linalg::PoolingNwcMinOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
                        PoolingType::MinSigned);
@@ -1764,11 +1890,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNwcMinUnsignedOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::PoolingNwcMinUnsignedOp>(op))
+  if (auto poolOp = dyn_cast<linalg::PoolingNwcMinUnsignedOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
                        PoolingType::MinUnsigned);
@@ -1788,11 +1917,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNdhwcSumOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::PoolingNdhwcSumOp>(op))
+  if (auto poolOp = dyn_cast<linalg::PoolingNdhwcSumOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides,
                        PoolingType::Sum);
@@ -1819,11 +1951,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNdhwcMaxOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::PoolingNdhwcMaxOp>(op))
+  if (auto poolOp = dyn_cast<linalg::PoolingNdhwcMaxOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides,
                        PoolingType::MaxSigned);
@@ -1850,11 +1985,14 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNdhwcMinOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::PoolingNdhwcMinOp>(op))
+  if (auto poolOp = dyn_cast<linalg::PoolingNdhwcMinOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
+  }
 
-  assert(isaConvolutionOpInterface(op) &&
-         "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op))
+    return false;
 
   ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides,
                        PoolingType::MinSigned);
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 60a4c555fa19a..7798cb76e4fb9 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -3,113 +3,168 @@
 // CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
+#map_nhwc_hwcf_input = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+#map_nhwc_hwcf_filter = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+#map_nhwc_hwcf_output = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+
 // CHECK-LABEL: @conv_2d_nhwc_hwcf
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
 func.func @conv_2d_nhwc_hwcf(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?x?x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_nwc_wcf
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Expect one conv_1d_nwc_wcf per source op (the named conv and the generic below). NOTE(review): the tensor.insert_slice results and the returned value are not matched; consider checking them too.
+  // CHECK-COUNT-2: linalg.conv_1d_nwc_wcf
   %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
                                  strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>)
     outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<?x1x?x?xf32>
+  // Equivalent linalg.generic form of the conv above; it takes %0 as its init, so both ops stay live.
+  %1 = linalg.generic {indexing_maps = [#map_nhwc_hwcf_input, #map_nhwc_hwcf_filter, #map_nhwc_hwcf_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.mulf %in, %in_0 : f32
+    %3 = arith.addf %out, %2 : f32
+    linalg.yield %3 : f32
+  } -> tensor<?x1x?x?xf32>
+  return %1 : tensor<?x1x?x?xf32>
 }
 
+#map_nchw_fchw_input = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
+#map_nchw_fchw_filter = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+#map_nchw_fchw_output = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+
 // CHECK-LABEL: @conv_2d_nchw_fchw
 // CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
 // CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
 func.func @conv_2d_nchw_fchw(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x?x1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_ncw_fcw
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Expect one conv_1d_ncw_fcw per source op (the named conv and the generic below). NOTE(review): the tensor.insert_slice results and the returned value are not matched; consider checking them too.
+  // CHECK-COUNT-2: linalg.conv_1d_ncw_fcw
   %0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>,
                                  strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>)
     outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<?x?x1x?xf32>
+  // Equivalent linalg.generic form of the conv above; it takes %0 as its init, so both ops stay live.
+  %1 = linalg.generic {indexing_maps = [#map_nchw_fchw_input, #map_nchw_fchw_filter, #map_nchw_fchw_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>) outs(%0 : tensor<?x?x1x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.mulf %in, %in_0 : f32
+    %3 = arith.addf %out, %2 : f32
+    linalg.yield %3 : f32
+  } -> tensor<?x?x1x?xf32>
+  return %1 : tensor<?x?x1x?xf32>
 }
 
+#map_depthwise_nhwc_hwc_input = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 2 + d5, d3)>
+#map_depthwise_nhwc_hwc_filter = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
+#map_depthwise_nhwc_hwc_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
 // CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
 // CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>
 func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>) -> tensor<1x1x56x96xf32> {
   // CHECK: %[[RES:.+]] = tensor.empty
   %init = tensor.empty() : tensor<1x1x56x96xf32>
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICERES:.+]] = tensor.extract_slice %[[RES]]
-  // CHECK: %[[OPRES:.+]] = linalg.depthwise_conv_1d_nwc_wc
-  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
-  // CHECK-SAME: outs(%[[SLICERES]]
-  // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[OPRES]] into %[[RES]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // Expect one depthwise_conv_1d_nwc_wc per source op (the named conv and the generic below). NOTE(review): the slice of the init, the conv's ins/outs operands and the tensor.insert_slice results are not matched; consider checking them too.
+  // CHECK-COUNT-2: linalg.depthwise_conv_1d_nwc_wc
   %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
          ins(%input, %filter: tensor<1x1x113x96xf32>, tensor<1x3x96xf32>)
          outs(%init: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32>
-  // CHECK: %[[INSERTED]]
-  return %0: tensor<1x1x56x96xf32>
+  // Equivalent linalg.generic form of the depthwise conv above; the stride of 2 appears as d1 * 2 / d2 * 2 in #map_depthwise_nhwc_hwc_input. It takes %0 as its init, so both ops stay live.
+  %1 = linalg.generic {indexing_maps = [#map_depthwise_nhwc_hwc_input, #map_depthwise_nhwc_hwc_filter, #map_depthwise_nhwc_hwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<1x1x113x96xf32>, tensor<1x3x96xf32>) outs(%0 : tensor<1x1x56x96xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.mulf %in, %in_0 : f32
+    %3 = arith.addf %out, %2 : f32
+    linalg.yield %3 : f32
+  } -> tensor<1x1x56x96xf32>
+  return %1: tensor<1x1x56x96xf32>
 }
 
+#map_conv_2d_input = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
+#map_conv_2d_filter = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map_conv_2d_output = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+
 // CHECK-LABEL: @conv_2d
 // CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<1x?xf32>,
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
 // CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<1x?xf32>)
 func.func @conv_2d(%input: tensor<1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<1x?xf32>) -> tensor<1x?xf32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Expect one conv_1d per source op (the named conv and the generic below). NOTE(review): the pattern "linalg.conv_1d" is also a prefix of other op names (e.g. linalg.conv_1d_nwc_wcf), and the insert_slice/return are not matched.
+  // CHECK-COUNT-2: linalg.conv_1d
   %0 = linalg.conv_2d
      ins (%input, %filter: tensor<1x?xf32>, tensor<1x?xf32>)
     outs (%init: tensor<1x?xf32>) -> tensor<1x?xf32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<1x?xf32>
+  // Equivalent linalg.generic form of the conv above; it takes %0 as its init, so both ops stay live.
+  %1 = linalg.generic {indexing_maps = [#map_conv_2d_input, #map_conv_2d_filter, #map_conv_2d_output], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<1x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<1x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.mulf %in, %in_0 : f32
+    %3 = arith.addf %out, %2 : f32
+    linalg.yield %3 : f32
+  } -> tensor<1x?xf32>
+  return %1 : tensor<1x?xf32>
 }
 
+#map_pooling_nhwc_input = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+#map_pooling_nhwc_filter = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+#map_pooling_nhwc_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
 // CHECK-LABEL: @pooling_nhwc_sum
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
 func.func @pooling_nhwc_sum(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_sum
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Expect one pooling_nwc_sum per source op (the named pooling op and the generic below). NOTE(review): the tensor.insert_slice results and the returned value are not matched; consider checking them too.
+  // CHECK-COUNT-2: linalg.pooling_nwc_sum
   %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
     outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<?x1x?x?xf32>
+  // Equivalent linalg.generic form of the sum pooling above; %in_0 (window operand) is unused in the body and presumably only bounds d4/d5. Takes %0 as init so both ops stay live.
+  %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.addf %out, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x1x?x?xf32>
+  return %1 : tensor<?x1x?x?xf32>
 }
 
+#map_pooling_nchw_input = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
+#map_pooling_nchw_filter = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+#map_pooling_nchw_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
 // CHECK-LABEL: @pooling_nchw_sum
 // CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
 // CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
 func.func @pooling_nchw_sum(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_sum
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Expect one pooling_ncw_sum per source op (the named pooling op and the generic below). NOTE(review): the tensor.insert_slice results and the returned value are not matched; consider checking them too.
+  // CHECK-COUNT-2: linalg.pooling_ncw_sum
   %0 = linalg.pooling_nchw_sum {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
     outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<?x?x1x?xf32>
+  // Equivalent linalg.generic form of the sum pooling above; %in_0 (window operand) is unused in the body and presumably only bounds d4/d5. Takes %0 as init so both ops stay live.
+  %1 = linalg.generic {indexing_maps = [#map_pooling_nchw_input, #map_pooling_nchw_filter, #map_pooling_nchw_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x?x1x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.addf %out, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x1x?xf32>
+  return %1 : tensor<?x?x1x?xf32>
 }
 
 // CHECK-LABEL: @pooling_nhwc_max
@@ -117,17 +172,22 @@ func.func @pooling_nchw_sum(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
 func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Expect one pooling_nwc_max per source op (the named pooling op and the generic below). NOTE(review): the tensor.insert_slice results and the returned value are not matched; consider checking them too.
+  // CHECK-COUNT-2: linalg.pooling_nwc_max
   %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
     outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<?x1x?x?xf32>
+  // Equivalent linalg.generic form of the max pooling above (arith.maximumf), reusing the #map_pooling_nhwc_* maps; it takes %0 as its init, so both ops stay live.
+  %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.maximumf %out, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x1x?x?xf32>
+  return %1 : tensor<?x1x?x?xf32>
 }
 
 // CHECK-LABEL: @pooling_nhwc_max_unsigned
@@ -135,17 +195,22 @@ func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
 func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max_unsigned
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Named op and generic op each decompose into one pooling_nwc_max_unsigned.
+  // CHECK-COUNT-2: linalg.pooling_nwc_max_unsigned
   %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
                                  strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
     outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<?x1x?x?xi32>
+  // Equivalent linalg.generic form of the named op, accumulating into %0.
+  %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xi32>, tensor<1x?xi32>) outs(%0 : tensor<?x1x?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %2 = arith.maxui %out, %in : i32
+    linalg.yield %2 : i32
+  } -> tensor<?x1x?x?xi32>
+  return %1 : tensor<?x1x?x?xi32>
 }
 
 // CHECK-LABEL: @pooling_nhwc_min
@@ -153,17 +218,22 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tenso
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
 func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Named op and generic op each decompose into one pooling_nwc_min.
+  // CHECK-COUNT-2: linalg.pooling_nwc_min
   %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>,
                                  strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
     outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<?x1x?x?xf32>
+  // Equivalent linalg.generic form of the named op, accumulating into %0.
+  %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.minimumf %out, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x1x?x?xf32>
+  return %1 : tensor<?x1x?x?xf32>
 }
 
 // CHECK-LABEL: @pooling_nhwc_min_unsigned
@@ -171,17 +241,22 @@ func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
 func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min_unsigned
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Named op and generic op each decompose into one pooling_nwc_min_unsigned.
+  // CHECK-COUNT-2: linalg.pooling_nwc_min_unsigned
   %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
                                  strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
     outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<?x1x?x?xi32>
+  // Equivalent linalg.generic form of the named op, accumulating into %0.
+  %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xi32>, tensor<1x?xi32>) outs(%0 : tensor<?x1x?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %2 = arith.minui %out, %in : i32
+    linalg.yield %2 : i32
+  } -> tensor<?x1x?x?xi32>
+  return %1 : tensor<?x1x?x?xi32>
 }
 
 // CHECK-LABEL: @pooling_nchw_max
@@ -189,17 +264,22 @@ func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tenso
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
 // CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
 func.func @pooling_nchw_max(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
-  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
-  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
-  // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_max
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  // CHECK: tensor.extract_slice %[[ARG0]]
+  // CHECK: tensor.extract_slice %[[ARG1]]
+  // CHECK: tensor.extract_slice %[[ARG2]]
+  // Named op and generic op each decompose into one pooling_ncw_max.
+  // CHECK-COUNT-2: linalg.pooling_ncw_max
   %0 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>,
                                  strides = dense<1> : tensor<2xi64>}
      ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
     outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
-  // CHECK: return %[[RES]]
-  return %0 : tensor<?x?x1x?xf32>
+  // Equivalent linalg.generic form of the named op, accumulating into %0.
+  %1 = linalg.generic {indexing_maps = [#map_pooling_nchw_input, #map_pooling_nchw_filter, #map_pooling_nchw_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x?x1x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.maximumf %out, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x1x?xf32>
+  return %1 : tensor<?x?x1x?xf32>
 }
 
 func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {

>From f3db3b69f7c7f4b4754b31bf5baf14f680a95b30 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Fri, 2 Jan 2026 09:42:58 +0000
Subject: [PATCH 2/2] [Linalg] Update Vectorization to work with generic
 convolution ops

-- This commit updates Vectorization to work with both named as well
   as generic convolution ops.
-- This required an update to the `isaConvolutionOpOfType` API to allow
   dilations/strides to be optional arguments.

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
 .../include/mlir/Dialect/Linalg/Utils/Utils.h |   8 +-
 .../Linalg/Transforms/Vectorization.cpp       |  96 ++++--
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 280 +++++++++++++++++-
 .../convolution-with-patterns-flatten.mlir    | 223 +++++++++++++-
 .../convolution-with-patterns.mlir            | 198 +++++++++++++
 .../Linalg/vectorization/convolution.mlir     | 147 +++++++++
 6 files changed, 908 insertions(+), 44 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 9da01f30b52d2..16d557a6ed7fa 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -108,10 +108,12 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
 
 /// Given a linalg `op` this function returns true if it is a convolution op of
 /// type `ConvOpTy` and populates `dilations` and `strides` with values inferred
-/// from the indexing maps.
+/// from the indexing maps (for named ops: read from their attributes). Either
+/// pointer may be nullptr (the default), in which case it is left unpopulated.
 template <typename ConvOpTy>
-bool isaConvolutionOpOfType(LinalgOp op, SmallVector<int64_t> *dilations,
-                            SmallVector<int64_t> *strides);
+bool isaConvolutionOpOfType(LinalgOp op,
+                            SmallVector<int64_t> *dilations = nullptr,
+                            SmallVector<int64_t> *strides = nullptr);
 
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bb3bccdae0e14..0f9a7a1751699 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -2070,7 +2071,7 @@ vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
     return failure();
   }
 
-  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
+  if (!isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
     LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
     return failure();
   }
@@ -2431,10 +2432,11 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   if (isElementwise(linalgOp))
     return success();
 
-  // TODO: isaConvolutionOpInterface that can also infer from generic
-  // features. But we will still need stride/dilation attributes that will be
-  // annoying to reverse-engineer...
-  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
+  // Check for convolution ops - both named ops implementing
+  // ConvolutionOpInterface and generic ops that semantically match convolution
+  // patterns.
+  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()) ||
+      isaConvolutionOpInterface(linalgOp))
     return vectorizeConvOpPrecondition(linalgOp);
 
   // TODO: the common vector shape is equal to the static loop sizes only when
@@ -2639,11 +2641,11 @@ vectorizeScalableVectorPrecondition(Operation *op,
 
   // Cond 4: Only the following ops are supported in the
   // presence of scalable vectors
-  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
-                 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
-                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
-                 isa<linalg::BatchMmt4DOp>(op) ||
-                 hasReductionIterator(linalgOp));
+  return success(
+      isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
+      isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(linalgOp) ||
+      isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+      isa<linalg::BatchMmt4DOp>(op) || hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2734,7 +2736,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
             // TODO: isaConvolutionOpInterface that can also infer from
             // generic features. Will require stride/dilation attributes
             // inference.
-            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
+            if (isa<ConvolutionOpInterface>(linalgOp.getOperation()) ||
+                isaConvolutionOpInterface(linalgOp)) {
               FailureOr<Operation *> convOr = vectorizeConvolution(
                   rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                   flatten1DDepthwiseConv);
@@ -3480,6 +3483,43 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
   bindShapeDims<0>(shapedType, vals...);
 }
 
+/// Matches `op` against the 1D conv/pool op types below, in order, and on the
+/// first match copies the first stride/dilation entry into `strideW` and
+/// `dilationW` and returns true; otherwise returns false, outputs untouched.
+static bool extract1DConvPoolStrideDilation(LinalgOp op, int &strideW,
+                                            int &dilationW) {
+  SmallVector<int64_t> dilations, strides;
+
+#define EXTRACT_1D_CONV_POOL_STRIDE_DILATION(ConvOpTy)                         \
+  if (isaConvolutionOpOfType<ConvOpTy>(op, &dilations, &strides)) {            \
+    strideW = static_cast<int>(strides.front());                               \
+    dilationW = static_cast<int>(dilations.front());                           \
+    return true;                                                               \
+  }
+
+  // 1D convolution ops (each macro use is a guarded early return).
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DNwcWcfOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DNcwFcwOp);
+  // Depthwise 1D convolution ops
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNwcWcOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNcwCwOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNwcWcmOp);
+  // 1D pooling ops (NWC layout)
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcSumOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMaxOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMaxUnsignedOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMinOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMinUnsignedOp);
+  // 1D pooling ops (NCW layout)
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNcwSumOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNcwMaxOp);
+
+#undef EXTRACT_1D_CONV_POOL_STRIDE_DILATION
+
+  return false;
+}
+
 namespace {
 /// Generate a vector implementation for either:
 /// ```
@@ -3535,14 +3575,19 @@ struct Conv1DGenerator
     auto maybeKind = getCombinerOpKind(reduceOp);
     reductionKind = maybeKind.value();
 
-    // The ConvolutionOpInterface gives us guarantees of existence for
-    // strides/dilations. However, we do not need to rely on those, we can
-    // simply use them if present, otherwise use the default and let the generic
-    // conv. matcher in the ConvGenerator succeed or fail.
-    auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
-    auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
-    strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
-    dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+    // Try to extract strides/dilations from named 1D conv/pool ops using
+    // isaConvolutionOpOfType. This works for both named ops and generic ops
+    // that match their semantics. For unrecognized generic ops, fall back to
+    // checking attributes directly (which may not exist for generic ops).
+    if (!extract1DConvPoolStrideDilation(linalgOp, strideW, dilationW)) {
+      // Fallback: check for stride/dilation attributes directly.
+      // For generic ops without these attributes, default to 1.
+      auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
+      auto dilations =
+          linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+      strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
+      dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+    }
   }
 
   /// Generate a vector implementation for:
@@ -4265,13 +4310,14 @@ static FailureOr<Operation *> vectorizeConvolution(
   if (!inputVecSizes.empty()) {
     // Only use the input vector size corresponding to the channel dim. Other
     // vector dims will be inferred from the Ops.
-    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
-            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
+    assert((isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(op) ||
+            isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(op)) &&
            "Not a 1D depthwise conv!");
-    size_t chDimIdx =
-        TypeSwitch<Operation *, size_t>(op)
-            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
-            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
+    size_t chDimIdx = 0;
+    if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(op))
+      chDimIdx = 2;
+    else if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(op))
+      chDimIdx = 1;
 
     vecChDimSize = inputVecSizes[chDimIdx];
     vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index eccf648e53014..1cdd01567c4e7 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -595,6 +595,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
                                               SmallVector<int64_t> *dilations,
                                               SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (isa<linalg::Conv1DOp>(op)) {
     // Conv1DOp has no strides/dilations attributes, default to 1.
     *dilations = SmallVector<int64_t>(1, 1);
@@ -620,6 +625,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto convOp = dyn_cast<linalg::Conv1DNwcWcfOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
@@ -647,6 +657,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto convOp = dyn_cast<linalg::Conv1DNcwFcwOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
@@ -674,6 +689,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
                                               SmallVector<int64_t> *dilations,
                                               SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (isa<linalg::Conv2DOp>(op)) {
     // Conv2DOp has no strides/dilations attributes, default to 1.
     *dilations = SmallVector<int64_t>(2, 1);
@@ -702,6 +722,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
@@ -732,6 +757,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfQOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
@@ -764,6 +794,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
@@ -794,6 +829,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcQOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
@@ -826,6 +866,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto convOp = dyn_cast<linalg::Conv2DNchwFchwOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
@@ -856,6 +901,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto convOp = dyn_cast<linalg::Conv2DNchwFchwQOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
@@ -888,6 +938,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto convOp = dyn_cast<linalg::Conv2DNgchwFgchwOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
@@ -920,6 +975,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto convOp = dyn_cast<linalg::Conv2DNgchwGfchwOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
@@ -952,6 +1012,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto convOp = dyn_cast<linalg::Conv2DNgchwGfchwQOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
@@ -986,6 +1051,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto convOp = dyn_cast<linalg::Conv2DNhwgcGfhwcOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
@@ -1018,6 +1088,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto convOp = dyn_cast<linalg::Conv2DNhwgcGfhwcQOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
@@ -1052,6 +1127,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
                                               SmallVector<int64_t> *dilations,
                                               SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (isa<linalg::Conv3DOp>(op)) {
     // Conv3DOp has no strides/dilations attributes, default to 1.
     *dilations = SmallVector<int64_t>(3, 1);
@@ -1084,6 +1164,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto convOp = dyn_cast<linalg::Conv3DNdhwcDhwcfOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
@@ -1118,6 +1203,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto convOp = dyn_cast<linalg::Conv3DNdhwcDhwcfQOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
@@ -1154,6 +1244,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::Conv3DNcdhwFcdhwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto convOp = dyn_cast<linalg::Conv3DNcdhwFcdhwOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
@@ -1188,7 +1283,13 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (auto convOp = dyn_cast<linalg::DepthwiseConv1DNcwCwOp>(op.getOperation())) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp =
+          dyn_cast<linalg::DepthwiseConv1DNcwCwOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
@@ -1214,7 +1315,13 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (auto convOp = dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation())) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp =
+          dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
@@ -1240,7 +1347,13 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (auto convOp = dyn_cast<linalg::DepthwiseConv1DNwcWcmOp>(op.getOperation())) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp =
+          dyn_cast<linalg::DepthwiseConv1DNwcWcmOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
@@ -1267,7 +1380,13 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (auto convOp = dyn_cast<linalg::DepthwiseConv2DNchwChwOp>(op.getOperation())) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp =
+          dyn_cast<linalg::DepthwiseConv2DNchwChwOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
@@ -1296,7 +1415,13 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (auto convOp = dyn_cast<linalg::DepthwiseConv2DNhwcHwcOp>(op.getOperation())) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp =
+          dyn_cast<linalg::DepthwiseConv2DNhwcHwcOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
@@ -1325,7 +1450,13 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (auto convOp = dyn_cast<linalg::DepthwiseConv2DNhwcHwcQOp>(op.getOperation())) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp =
+          dyn_cast<linalg::DepthwiseConv2DNhwcHwcQOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
@@ -1356,7 +1487,13 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (auto convOp = dyn_cast<linalg::DepthwiseConv2DNhwcHwcmOp>(op.getOperation())) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp =
+          dyn_cast<linalg::DepthwiseConv2DNhwcHwcmOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
@@ -1386,7 +1523,13 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmQOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (auto convOp = dyn_cast<linalg::DepthwiseConv2DNhwcHwcmQOp>(op.getOperation())) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp =
+          dyn_cast<linalg::DepthwiseConv2DNhwcHwcmQOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
@@ -1418,7 +1561,13 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (auto convOp = dyn_cast<linalg::DepthwiseConv3DNdhwcDhwcOp>(op.getOperation())) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp =
+          dyn_cast<linalg::DepthwiseConv3DNdhwcDhwcOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
@@ -1451,7 +1600,13 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNcdhwCdhwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (auto convOp = dyn_cast<linalg::DepthwiseConv3DNcdhwCdhwOp>(op.getOperation())) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp =
+          dyn_cast<linalg::DepthwiseConv3DNcdhwCdhwOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
@@ -1484,7 +1639,13 @@ template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (auto convOp = dyn_cast<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op.getOperation())) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp =
+          dyn_cast<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op.getOperation())) {
     *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
     return true;
@@ -1518,6 +1679,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto poolOp = dyn_cast<linalg::PoolingNhwcMaxOp>(op.getOperation())) {
     *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
@@ -1548,6 +1714,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto poolOp = dyn_cast<linalg::PoolingNhwcMinOp>(op.getOperation())) {
     *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
@@ -1578,6 +1749,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto poolOp = dyn_cast<linalg::PoolingNhwcSumOp>(op.getOperation())) {
     *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
@@ -1608,7 +1784,13 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (auto poolOp = dyn_cast<linalg::PoolingNhwcMaxUnsignedOp>(op.getOperation())) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto poolOp =
+          dyn_cast<linalg::PoolingNhwcMaxUnsignedOp>(op.getOperation())) {
     *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
@@ -1638,7 +1820,13 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (auto poolOp = dyn_cast<linalg::PoolingNhwcMinUnsignedOp>(op.getOperation())) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto poolOp =
+          dyn_cast<linalg::PoolingNhwcMinUnsignedOp>(op.getOperation())) {
     *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
@@ -1668,6 +1856,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNchwSumOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto poolOp = dyn_cast<linalg::PoolingNchwSumOp>(op.getOperation())) {
     *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
@@ -1698,6 +1891,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNchwMaxOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto poolOp = dyn_cast<linalg::PoolingNchwMaxOp>(op.getOperation())) {
     *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
@@ -1728,6 +1926,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNwcSumOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto poolOp = dyn_cast<linalg::PoolingNwcSumOp>(op.getOperation())) {
     *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
@@ -1755,6 +1958,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNcwSumOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto poolOp = dyn_cast<linalg::PoolingNcwSumOp>(op.getOperation())) {
     *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
@@ -1782,6 +1990,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNwcMaxOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto poolOp = dyn_cast<linalg::PoolingNwcMaxOp>(op.getOperation())) {
     *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
@@ -1809,7 +2022,13 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNwcMaxUnsignedOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (auto poolOp = dyn_cast<linalg::PoolingNwcMaxUnsignedOp>(op.getOperation())) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto poolOp =
+          dyn_cast<linalg::PoolingNwcMaxUnsignedOp>(op.getOperation())) {
     *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
@@ -1836,6 +2055,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNcwMaxOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto poolOp = dyn_cast<linalg::PoolingNcwMaxOp>(op.getOperation())) {
     *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
@@ -1863,6 +2087,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNwcMinOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto poolOp = dyn_cast<linalg::PoolingNwcMinOp>(op.getOperation())) {
     *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
@@ -1890,7 +2119,13 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNwcMinUnsignedOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (auto poolOp = dyn_cast<linalg::PoolingNwcMinUnsignedOp>(op.getOperation())) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto poolOp =
+          dyn_cast<linalg::PoolingNwcMinUnsignedOp>(op.getOperation())) {
     *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
     return true;
@@ -1917,6 +2152,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNdhwcSumOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto poolOp = dyn_cast<linalg::PoolingNdhwcSumOp>(op.getOperation())) {
     *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
@@ -1951,6 +2191,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNdhwcMaxOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto poolOp = dyn_cast<linalg::PoolingNdhwcMaxOp>(op.getOperation())) {
     *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
@@ -1985,6 +2230,11 @@ template <>
 bool isaConvolutionOpOfType<linalg::PoolingNdhwcMinOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
+  SmallVector<int64_t> localDilations, localStrides;
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
   if (auto poolOp = dyn_cast<linalg::PoolingNdhwcMinOp>(op.getOperation())) {
     *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
     *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
diff --git a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir
index c47824a18cf56..40641bfed659f 100644
--- a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir
@@ -52,7 +52,62 @@ module attributes {transform.with_named_sequence} {
 // CHECK:           %[[SC_ADDI:.*]] = vector.shape_cast %[[ADDI]] : vector<1x24xi8> to vector<1x8x3xi8>
 // CHECK:           vector.transfer_write %[[SC_ADDI]], %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]]
 
-//------
+// -----
+
+// Generic version of depthwise_conv_1d_nwc_wc with stride=1, dilation=1
+func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor_generic(%input: tensor<1x8x3xi8>,
+                                                           %filter: tensor<1x3xi8>,
+                                                           %output: tensor<1x8x3xi8>) -> (tensor<1x8x3xi8>) {
+  %res = linalg.generic {
+    indexing_maps = [
+      affine_map<(n, ow, c, kw) -> (n, ow + kw, c)>,
+      affine_map<(n, ow, c, kw) -> (kw, c)>,
+      affine_map<(n, ow, c, kw) -> (n, ow, c)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+  } ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
+    outs(%output : tensor<1x8x3xi8>) {
+  ^bb0(%in: i8, %flt: i8, %out: i8):
+    %mul = arith.muli %in, %flt : i8
+    %add = arith.addi %out, %mul : i8
+    linalg.yield %add : i8
+  } -> tensor<1x8x3xi8>
+  return %res : tensor<1x8x3xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor_generic
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x8x3xi8>,
+// CHECK-SAME:      %[[FILTER:.*]]: tensor<1x3xi8>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x8x3xi8>) -> tensor<1x8x3xi8> {
+
+// CHECK-DAG:       %[[C0_IDX:.*]] = arith.constant 0 : index
+
+/// Read the whole data in one shot.
+// CHECK:           vector.transfer_read %[[INPUT]]
+// CHECK:           vector.transfer_read %[[FILTER]]
+// CHECK:           vector.transfer_read %[[OUTPUT]]
+
+/// Check for flattened depthwise conv vectorization pattern
+// CHECK:           vector.shape_cast {{.*}} : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK:           vector.shuffle
+// CHECK:           vector.broadcast
+// CHECK:           arith.muli {{.*}} : vector<1x24xi8>
+// CHECK:           arith.addi {{.*}} : vector<1x24xi8>
+
+/// Write the result back.
+// CHECK:           vector.shape_cast {{.*}} : vector<1x24xi8> to vector<1x8x3xi8>
+// CHECK:           vector.transfer_write
+
+// -----
 
 func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3x5x4xf32>,
                                                                 %filter: memref<2x4xf32>,
@@ -115,6 +170,60 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Generic version of depthwise_conv_1d_nwc_wc with stride=1, dilation=2 (f32)
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2_generic(%input: memref<3x5x4xf32>,
+                                                                         %filter: memref<2x4xf32>,
+                                                                         %output: memref<3x2x4xf32>) {
+  // dilation=2 is encoded by the 2*kw term of the input indexing map (ow + 2*kw).
+  linalg.generic {
+    indexing_maps = [
+      affine_map<(n, ow, c, kw) -> (n, ow + 2*kw, c)>,
+      affine_map<(n, ow, c, kw) -> (kw, c)>,
+      affine_map<(n, ow, c, kw) -> (n, ow, c)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+  } ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
+    outs(%output : memref<3x2x4xf32>) {
+  ^bb0(%in: f32, %flt: f32, %out: f32):
+    %mul = arith.mulf %in, %flt : f32
+    %add = arith.addf %out, %mul : f32
+    linalg.yield %add : f32
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+//       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2_generic
+//  CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)
+
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+
+/// Read the whole data in one shot.
+//      CHECK:   vector.transfer_read %[[INPUT]]
+//      CHECK:   vector.transfer_read %[[FILTER]]
+//      CHECK:   vector.transfer_read %[[OUTPUT]]
+
+/// Check for flattened depthwise conv vectorization pattern with dilation=2
+// CHECK:        vector.shape_cast {{.*}} : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK:        vector.shuffle
+// CHECK:        vector.broadcast
+// CHECK:        vector.fma {{.*}} : vector<3x8xf32>
+// CHECK:        vector.fma {{.*}} : vector<3x8xf32>
+
+/// Write the result back.
+// CHECK:        vector.shape_cast {{.*}} : vector<3x8xf32> to vector<3x2x4xf32>
+// CHECK:        vector.transfer_write {{.*}}, %[[OUTPUT]]
+
+// -----
+
 func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref_dilation_2(%input: memref<3x5x4xi8>,
                                                               %filter: memref<2x4xi8>,
                                                               %output: memref<3x2x4xi32>) {
@@ -179,6 +288,64 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Generic version of depthwise_conv_1d_nwc_wc with stride=1, dilation=2 (i8->i32)
+func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref_dilation_2_generic(%input: memref<3x5x4xi8>,
+                                                                       %filter: memref<2x4xi8>,
+                                                                       %output: memref<3x2x4xi32>) {
+  // dilation=2 is encoded by the 2*kw term of the input indexing map (ow + 2*kw).
+  linalg.generic {
+    indexing_maps = [
+      affine_map<(n, ow, c, kw) -> (n, ow + 2*kw, c)>,
+      affine_map<(n, ow, c, kw) -> (kw, c)>,
+      affine_map<(n, ow, c, kw) -> (n, ow, c)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+  } ins(%input, %filter : memref<3x5x4xi8>, memref<2x4xi8>)
+    outs(%output : memref<3x2x4xi32>) {
+  ^bb0(%in: i8, %flt: i8, %out: i32):
+    %in_ext = arith.extsi %in : i8 to i32
+    %flt_ext = arith.extsi %flt : i8 to i32
+    %mul = arith.muli %in_ext, %flt_ext : i32
+    %add = arith.addi %out, %mul : i32
+    linalg.yield %add : i32
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+//       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref_dilation_2_generic
+//  CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xi8>, %[[FILTER:[0-9a-z]+]]: memref<2x4xi8>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xi32>)
+
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+
+/// Read the whole data in one shot.
+//      CHECK:   vector.transfer_read %[[INPUT]]
+//      CHECK:   vector.transfer_read %[[FILTER]]
+//      CHECK:   vector.transfer_read %[[OUTPUT]]
+
+/// Check for flattened depthwise conv vectorization pattern with dilation=2 and i8->i32 extension
+// CHECK:        vector.shape_cast
+// CHECK:        arith.extsi {{.*}} : vector<3x8xi8> to vector<3x8xi32>
+// CHECK:        vector.shuffle
+// CHECK:        arith.extsi
+// CHECK:        vector.broadcast
+// CHECK:        arith.muli {{.*}} : vector<3x8xi32>
+// CHECK:        arith.addi {{.*}} : vector<3x8xi32>
+
+/// Write the result back.
+// CHECK:        vector.shape_cast {{.*}} : vector<3x8xi32> to vector<3x2x4xi32>
+// CHECK:        vector.transfer_write {{.*}}, %[[OUTPUT]]
+
+// -----
+
 func.func @depthwise_conv1d_nwc_wc_3x9x4xi8_tensor_stride_2(%input: tensor<3x9x4xi8>,
                                                             %filter: tensor<3x4xi8>,
                                                             %output: tensor<3x3x4xi8>) -> tensor<3x3x4xi8> {
@@ -309,3 +476,57 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+// Test that a linalg.generic with depthwise conv semantics is vectorized with flatten_1d_depthwise_conv.
+// Generic version of depthwise_conv_1d_nwc_wc with stride=2, dilation=1
+
+func.func @depthwise_conv1d_nwc_wc_generic_flatten(%input: tensor<3x9x4xi8>,
+                                                    %filter: tensor<3x4xi8>,
+                                                    %output: tensor<3x3x4xi8>) -> tensor<3x3x4xi8> {
+  %res = linalg.generic {
+    indexing_maps = [
+      affine_map<(n, ow, c, kw) -> (n, ow * 2 + kw, c)>,  // input (stride=2, dilation=1)
+      affine_map<(n, ow, c, kw) -> (kw, c)>,              // filter
+      affine_map<(n, ow, c, kw) -> (n, ow, c)>            // output
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+  } ins(%input, %filter : tensor<3x9x4xi8>, tensor<3x4xi8>)
+    outs(%output : tensor<3x3x4xi8>) {
+  ^bb0(%in: i8, %flt: i8, %out: i8):
+    %mul = arith.muli %in, %flt : i8
+    %add = arith.addi %out, %mul : i8
+    linalg.yield %add : i8
+  } -> tensor<3x3x4xi8>
+  return %res : tensor<3x3x4xi8>
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_generic_flatten
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<3x9x4xi8>,
+// CHECK-SAME:      %[[FILTER:.*]]: tensor<3x4xi8>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<3x3x4xi8>) -> tensor<3x3x4xi8> {
+
+// CHECK-DAG:       %[[C0_IDX:.*]] = arith.constant 0 : index
+
+/// Read the whole data in one shot.
+// CHECK:           vector.transfer_read %[[INPUT]]
+// CHECK:           vector.transfer_read %[[FILTER]]
+// CHECK:           vector.transfer_read %[[OUTPUT]]
+
+/// Check for depthwise conv vectorization pattern (shape_cast + broadcast + muli + addi)
+// CHECK:           vector.shape_cast
+// CHECK:           vector.broadcast
+// CHECK:           arith.muli
+// CHECK:           arith.addi
+
+/// Write the result back.
+// CHECK:           vector.transfer_write
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
index cea60842f4606..dc63aa1fe2a48 100644
--- a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
@@ -1213,3 +1213,201 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Test that generic ops with convolution semantics are vectorized.
+// This is a generic version of conv_1d_nwc_wcf with non-trivial stride:
+// Input: (N=4, IW=6, C=3), Filter: (KW=1, C=3, F=8), Output: (N=4, OW=2, F=8)
+// strides=3, dilations=1 (matching the first named conv test)
+
+func.func @conv1d_nwc_generic(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf32>, %output: memref<4x2x8xf32>) {
+  linalg.generic {
+    indexing_maps = [
+      affine_map<(n, ow, f, kw, c) -> (n, ow * 3 + kw, c)>,  // input (stride=3, dilation=1)
+      affine_map<(n, ow, f, kw, c) -> (kw, c, f)>,           // filter
+      affine_map<(n, ow, f, kw, c) -> (n, ow, f)>            // output
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
+  } ins(%input, %filter : memref<4x6x3xf32>, memref<1x3x8xf32>)
+    outs(%output : memref<4x2x8xf32>) {
+  ^bb0(%in: f32, %flt: f32, %out: f32):
+    %mul = arith.mulf %in, %flt : f32
+    %add = arith.addf %out, %mul : f32
+    linalg.yield %add : f32
+  }
+  return
+}
+
+// CHECK-LABEL: func @conv1d_nwc_generic
+// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<1x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
+
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+/// Read the whole data in one shot.
+//  CHECK-DAG:   vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//  CHECK-DAG:   vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//  CHECK-DAG:   vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+
+/// Check for vector.contract operation (convolution vectorized with stride=3)
+//      CHECK:   vector.contract
+
+/// Write the result back in one shot.
+//      CHECK:   vector.transfer_write {{.*}}, %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Test that generic ops with depthwise convolution semantics are vectorized.
+// This is a generic version of depthwise_conv_1d_nwc_wc with non-trivial dilation:
+// Input: (N=3, IW=5, C=4), Filter: (KW=2, C=4), Output: (N=3, OW=2, C=4)
+// strides=1, dilations=2 (matching the named depthwise conv test)
+
+func.func @depthwise_conv1d_nwc_wc_generic(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
+  linalg.generic {
+    indexing_maps = [
+      affine_map<(n, ow, c, kw) -> (n, ow + kw * 2, c)>,  // input (stride=1, dilation=2)
+      affine_map<(n, ow, c, kw) -> (kw, c)>,              // filter
+      affine_map<(n, ow, c, kw) -> (n, ow, c)>            // output
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+  } ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
+    outs(%output : memref<3x2x4xf32>) {
+  ^bb0(%in: f32, %flt: f32, %out: f32):
+    %mul = arith.mulf %in, %flt : f32
+    %add = arith.addf %out, %mul : f32
+    linalg.yield %add : f32
+  }
+  return
+}
+
+// CHECK-LABEL: func @depthwise_conv1d_nwc_wc_generic
+// CHECK-SAME: (%[[INPUT:.+]]: memref<3x5x4xf32>, %[[FILTER:.+]]: memref<2x4xf32>, %[[OUTPUT:.+]]: memref<3x2x4xf32>)
+
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+/// Read the whole data in one shot.
+//  CHECK-DAG:   vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//  CHECK-DAG:   vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]], %[[F0]]
+//  CHECK-DAG:   vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+
+/// Check for vector.fma operation (depthwise conv vectorized with dilation=2)
+//      CHECK:   vector.fma
+
+/// Write the result back in one shot.
+//      CHECK:   vector.transfer_write {{.*}}, %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Test that a linalg.generic with pooling_nwc_sum semantics is vectorized.
+// Generic version of linalg.pooling_nwc_sum with stride=3, dilation=1.
+// Input: (N=4, IW=4, C=3), Filter/window: (KW=1), Output: (N=4, OW=2, C=3)
+
+func.func @pooling_nwc_sum_generic(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {
+  linalg.generic {
+    indexing_maps = [
+      affine_map<(n, ow, c, kw) -> (n, ow * 3 + kw, c)>,  // input (stride=3, dilation=1)
+      affine_map<(n, ow, c, kw) -> (kw)>,                  // filter (window shape only; unused by the body)
+      affine_map<(n, ow, c, kw) -> (n, ow, c)>             // output
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+  } ins(%input, %filter : memref<4x4x3xf32>, memref<1xf32>)
+    outs(%output : memref<4x2x3xf32>) {
+  ^bb0(%in: f32, %flt: f32, %out: f32):
+    %add = arith.addf %out, %in : f32
+    linalg.yield %add : f32
+  }
+  return
+}
+
+// CHECK-LABEL: func @pooling_nwc_sum_generic
+// CHECK-SAME: (%[[INPUT:.+]]: memref<4x4x3xf32>, %[[FILTER:.+]]: memref<1xf32>, %[[OUTPUT:.+]]: memref<4x2x3xf32>)
+
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+/// Input and output are each read in one shot; the filter only encodes the window shape.
+//  CHECK-DAG:   vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//  CHECK-DAG:   vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+
+/// Sum pooling accumulates the input into the output with arith.addf.
+//      CHECK:   arith.addf
+
+/// Write the result back in one shot.
+//      CHECK:   vector.transfer_write {{.*}}, %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Test that a linalg.generic with pooling_nwc_max semantics is vectorized.
+// Generic version of linalg.pooling_nwc_max with stride=3, dilation=1.
+// Input: (N=4, IW=4, C=3), Filter/window: (KW=1), Output: (N=4, OW=2, C=3)
+
+func.func @pooling_nwc_max_generic(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {
+  linalg.generic {
+    indexing_maps = [
+      affine_map<(n, ow, c, kw) -> (n, ow * 3 + kw, c)>,  // input (stride=3, dilation=1)
+      affine_map<(n, ow, c, kw) -> (kw)>,                  // filter (window shape only; unused by the body)
+      affine_map<(n, ow, c, kw) -> (n, ow, c)>             // output
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+  } ins(%input, %filter : memref<4x4x3xf32>, memref<1xf32>)
+    outs(%output : memref<4x2x3xf32>) {
+  ^bb0(%in: f32, %flt: f32, %out: f32):
+    %max = arith.maximumf %out, %in : f32
+    linalg.yield %max : f32
+  }
+  return
+}
+
+// CHECK-LABEL: func @pooling_nwc_max_generic
+// CHECK-SAME: (%[[INPUT:.+]]: memref<4x4x3xf32>, %[[FILTER:.+]]: memref<1xf32>, %[[OUTPUT:.+]]: memref<4x2x3xf32>)
+
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+/// Input and output are each read in one shot; the filter only encodes the window shape.
+//  CHECK-DAG:   vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//  CHECK-DAG:   vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+
+/// Max pooling combines the input with the output via arith.maximumf.
+//      CHECK:   arith.maximumf
+
+/// Write the result back in one shot.
+//      CHECK:   vector.transfer_write {{.*}}, %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization/convolution.mlir b/mlir/test/Dialect/Linalg/vectorization/convolution.mlir
index 5c321d40f6c60..fd5554b169860 100644
--- a/mlir/test/Dialect/Linalg/vectorization/convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/convolution.mlir
@@ -70,6 +70,54 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Generic version of depthwise_conv_1d_nwc_wc (stride=1, dilation=1) with a dynamic channel dim.
+func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor_generic(%input: tensor<1x8x?xi8>,
+                                                           %filter: tensor<1x?xi8>,
+                                                           %output: tensor<1x8x?xi8>) -> (tensor<1x8x?xi8>) {
+  %res = linalg.generic {
+    indexing_maps = [
+      affine_map<(n, ow, c, kw) -> (n, ow + kw, c)>,
+      affine_map<(n, ow, c, kw) -> (kw, c)>,
+      affine_map<(n, ow, c, kw) -> (n, ow, c)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+  } ins(%input, %filter : tensor<1x8x?xi8>, tensor<1x?xi8>)
+    outs(%output : tensor<1x8x?xi8>) {
+  ^bb0(%in: i8, %flt: i8, %out: i8):
+    %mul = arith.muli %in, %flt : i8
+    %add = arith.addi %out, %mul : i8
+    linalg.yield %add : i8
+  } -> tensor<1x8x?xi8>
+  return %res : tensor<1x8x?xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 8, 4, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor_generic(
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x8x?xi8>,
+// CHECK-SAME:      %[[FILTER:.*]]: tensor<1x?xi8>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x8x?xi8>) -> tensor<1x8x?xi8> {
+
+/// Expected to match the vectorized output of the corresponding named-op test.
+// CHECK:           vector.create_mask {{.*}} : vector<1x8x4xi1>
+// CHECK:           vector.mask {{.*}} { vector.transfer_read %[[INPUT]]{{.*}} } : vector<1x8x4xi1> -> vector<1x8x4xi8>
+// CHECK:           vector.create_mask {{.*}} : vector<1x4xi1>
+// CHECK:           vector.mask {{.*}} { vector.transfer_read %[[FILTER]]{{.*}} } : vector<1x4xi1> -> vector<1x4xi8>
+// CHECK:           vector.create_mask {{.*}} : vector<1x8x4xi1>
+// CHECK:           vector.mask {{.*}} { vector.transfer_read %[[OUTPUT]]{{.*}} } : vector<1x8x4xi1> -> vector<1x8x4xi8>
+// CHECK:           vector.broadcast {{.*}} : vector<4xi8> to vector<1x8x4xi8>
+// CHECK:           arith.muli {{.*}} : vector<1x8x4xi8>
+// CHECK:           arith.addi {{.*}} : vector<1x8x4xi8>
+// CHECK:           vector.mask {{.*}} { vector.transfer_write {{.*}} } : vector<1x8x4xi1> -> tensor<1x8x?xi8>
+
+// -----
+
 func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor_scalable(
       %input: tensor<1x8x?xi8>,
       %filter: tensor<1x?xi8>,
@@ -132,6 +180,55 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Generic depthwise_conv_1d_nwc_wc (stride=1, dilation=1); channel dim vectorized as scalable [4].
+func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor_scalable_generic(
+      %input: tensor<1x8x?xi8>,
+      %filter: tensor<1x?xi8>,
+      %output: tensor<1x8x?xi8>) -> (tensor<1x8x?xi8>) {
+  %res = linalg.generic {
+    indexing_maps = [
+      affine_map<(n, ow, c, kw) -> (n, ow + kw, c)>,
+      affine_map<(n, ow, c, kw) -> (kw, c)>,
+      affine_map<(n, ow, c, kw) -> (n, ow, c)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+  } ins(%input, %filter : tensor<1x8x?xi8>, tensor<1x?xi8>)
+    outs(%output : tensor<1x8x?xi8>) {
+  ^bb0(%in: i8, %flt: i8, %out: i8):
+    %mul = arith.muli %in, %flt : i8
+    %add = arith.addi %out, %mul : i8
+    linalg.yield %add : i8
+  } -> tensor<1x8x?xi8>
+  return %res : tensor<1x8x?xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 8, [4], 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor_scalable_generic(
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x8x?xi8>,
+// CHECK-SAME:      %[[FILTER:.*]]: tensor<1x?xi8>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x8x?xi8>) -> tensor<1x8x?xi8> {
+
+/// Expected to match the scalable vectorized output of the corresponding named-op test.
+// CHECK:           vector.create_mask {{.*}} : vector<1x8x[4]xi1>
+// CHECK:           vector.mask {{.*}} { vector.transfer_read %[[INPUT]]{{.*}} } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+// CHECK:           vector.create_mask {{.*}} : vector<1x[4]xi1>
+// CHECK:           vector.mask {{.*}} { vector.transfer_read %[[FILTER]]{{.*}} } : vector<1x[4]xi1> -> vector<1x[4]xi8>
+// CHECK:           vector.create_mask {{.*}} : vector<1x8x[4]xi1>
+// CHECK:           vector.mask {{.*}} { vector.transfer_read %[[OUTPUT]]{{.*}} } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+// CHECK:           vector.broadcast {{.*}} : vector<[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           arith.muli {{.*}} : vector<1x8x[4]xi8>
+// CHECK:           arith.addi {{.*}} : vector<1x8x[4]xi8>
+// CHECK:           vector.mask {{.*}} { vector.transfer_write {{.*}} } : vector<1x8x[4]xi1> -> tensor<1x8x?xi8>
+
+// -----
+
 func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
       %input: memref<3x5x?xf32>,
       %filter: memref<2x?xf32>,
@@ -193,3 +290,53 @@ module attributes {transform.with_named_sequence} {
 // CHECK:           %[[FMA_2:.*]] = vector.fma %[[IN_2]], %[[FLT_2_B]], %[[FMA_1]] : vector<3x2x[4]xf32>
 // CHECK:           %[[OUT_INS:.*]] = vector.insert_strided_slice %[[FMA_2]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x2x[4]xf32> into vector<3x2x[4]xf32>
 // CHECK:           vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>
+
+// -----
+
+// Generic version of depthwise_conv_1d_nwc_wc with stride=1, dilation=2.
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2_generic(
+      %input: memref<3x5x?xf32>,
+      %filter: memref<2x?xf32>,
+      %output: memref<3x2x?xf32>) {
+  // dilation=2: filter tap kw reads input column ow + 2*kw (see the input indexing map).
+  linalg.generic {
+    indexing_maps = [
+      affine_map<(n, ow, c, kw) -> (n, ow + 2*kw, c)>,
+      affine_map<(n, ow, c, kw) -> (kw, c)>,
+      affine_map<(n, ow, c, kw) -> (n, ow, c)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+  } ins(%input, %filter : memref<3x5x?xf32>, memref<2x?xf32>)
+    outs(%output : memref<3x2x?xf32>) {
+  ^bb0(%in: f32, %flt: f32, %out: f32):
+    %mul = arith.mulf %in, %flt : f32
+    %add = arith.addf %out, %mul : f32
+    linalg.yield %add : f32
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [3, 2, [4], 2] : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2_generic(
+// CHECK-SAME:      %[[INPUT:.*]]: memref<3x5x?xf32>,
+// CHECK-SAME:      %[[FILTER:.*]]: memref<2x?xf32>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: memref<3x2x?xf32>) {
+
+/// Expected to match the vectorized output of the named-op dilation=2 test.
+// CHECK:           vector.create_mask {{.*}} : vector<3x4x[4]xi1>
+// CHECK:           vector.mask {{.*}} { vector.transfer_read %[[INPUT]]{{.*}} } : vector<3x4x[4]xi1> -> vector<3x4x[4]xf32>
+// CHECK:           vector.create_mask {{.*}} : vector<2x[4]xi1>
+// CHECK:           vector.mask {{.*}} { vector.transfer_read %[[FILTER]]{{.*}} } : vector<2x[4]xi1> -> vector<2x[4]xf32>
+// CHECK:           vector.create_mask {{.*}} : vector<3x2x[4]xi1>
+// CHECK:           vector.mask {{.*}} { vector.transfer_read %[[OUTPUT]]{{.*}} } : vector<3x2x[4]xi1> -> vector<3x2x[4]xf32>
+/// Two FMAs, one per filter tap (KW=2); dilation only shifts which input columns each tap reads.
+// CHECK:           vector.fma {{.*}} : vector<3x2x[4]xf32>
+// CHECK:           vector.fma {{.*}} : vector<3x2x[4]xf32>
+// CHECK:           vector.mask {{.*}} { vector.transfer_write {{.*}} } : vector<3x2x[4]xi1>



More information about the Mlir-commits mailing list