[Mlir-commits] [mlir] [Linalg] Update transform + vectorization patterns to work with generic convolution ops as well (PR #174196)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 2 03:10:45 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Abhishek Varma (Abhishek-Varma)

<details>
<summary>Changes</summary>

-- This commit updates Linalg Transforms and Vectorization patterns
    to work with both named and generic convolution ops.
-- This required the following updates to the `isaConvolutionOpOfType` API :-
    1. Allow dilations/strides to be optional arguments.
    2. Populate dilations/strides info for named convolution ops as well.
    3. Since a "generic" LinalgOp is now used as the root op in the patterns
   above, the `assert` that the op implements ConvolutionOpInterface has
   been replaced with an early-exit `if`.

Signed-off-by: Abhishek Varma <abhvarma@<!-- -->amd.com>

---

Patch is 114.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/174196.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+13-14) 
- (modified) mlir/include/mlir/Dialect/Linalg/Utils/Utils.h (+5-3) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+56-51) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+71-25) 
- (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+537-149) 
- (modified) mlir/test/Dialect/Linalg/transform-op-decompose.mlir (+159-79) 
- (modified) mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir (+222-1) 
- (modified) mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir (+198) 
- (modified) mlir/test/Dialect/Linalg/vectorization/convolution.mlir (+147) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6678d693719bf..32067358438d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1641,16 +1641,16 @@ FailureOr<linalg::GenericOp> deduplicateOperandsAndRemoveDeadResults(
 //===----------------------------------------------------------------------===//
 
 /// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
-/// convolution ops.
+/// convolution ops. Works with both named ops and equivalent generic ops.
 template <typename Conv2DOp, typename Conv1DOp>
 struct DownscaleSizeOneWindowed2DConvolution final
-    : public OpRewritePattern<Conv2DOp> {
-  using OpRewritePattern<Conv2DOp>::OpRewritePattern;
+    : public OpInterfaceRewritePattern<LinalgOp> {
+  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
 
-  FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
+  FailureOr<Conv1DOp> returningMatchAndRewrite(LinalgOp convOp,
                                                PatternRewriter &rewriter) const;
 
-  LogicalResult matchAndRewrite(Conv2DOp convOp,
+  LogicalResult matchAndRewrite(LinalgOp convOp,
                                 PatternRewriter &rewriter) const override {
     return returningMatchAndRewrite(convOp, rewriter);
   }
@@ -1664,29 +1664,28 @@ extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
 /// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
 /// dimensions into 1-D depthwise convolution ops.
 struct DownscaleDepthwiseConv2DNhwcHwcOp final
-    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
+    : public OpInterfaceRewritePattern<LinalgOp> {
   DownscaleDepthwiseConv2DNhwcHwcOp(MLIRContext *context,
                                     PatternBenefit benefit = 1)
-      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit) {}
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}
 
   FailureOr<DepthwiseConv1DNwcWcOp>
-  returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
-                           PatternRewriter &rewriter) const;
+  returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const;
 
-  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
+  LogicalResult matchAndRewrite(LinalgOp convOp,
                                 PatternRewriter &rewriter) const override {
     return returningMatchAndRewrite(convOp, rewriter);
   }
 };
 
-struct DownscaleConv2DOp final : public OpRewritePattern<Conv2DOp> {
+struct DownscaleConv2DOp final : public OpInterfaceRewritePattern<LinalgOp> {
   DownscaleConv2DOp(MLIRContext *context, PatternBenefit benefit = 1)
-      : OpRewritePattern<Conv2DOp>(context, benefit) {}
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}
 
-  FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
+  FailureOr<Conv1DOp> returningMatchAndRewrite(LinalgOp convOp,
                                                PatternRewriter &rewriter) const;
 
-  LogicalResult matchAndRewrite(Conv2DOp convOp,
+  LogicalResult matchAndRewrite(LinalgOp convOp,
                                 PatternRewriter &rewriter) const override {
     return returningMatchAndRewrite(convOp, rewriter);
   }
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 9da01f30b52d2..16d557a6ed7fa 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -108,10 +108,12 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
 
 /// Given a linalg `op` this function returns true if it is a convolution op of
 /// type `ConvOpTy` and populates `dilations` and `strides` with values inferred
-/// from the indexing maps.
+/// from the indexing maps. If `dilations` or `strides` is nullptr, the
+/// corresponding values are not populated.
 template <typename ConvOpTy>
-bool isaConvolutionOpOfType(LinalgOp op, SmallVector<int64_t> *dilations,
-                            SmallVector<int64_t> *strides);
+bool isaConvolutionOpOfType(LinalgOp op,
+                            SmallVector<int64_t> *dilations = nullptr,
+                            SmallVector<int64_t> *strides = nullptr);
 
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 96cc378f6c21a..7972408318b95 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -32,6 +32,7 @@
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/InterleavedRange.h"
 #include "llvm/Support/raw_ostream.h"
+#include <type_traits>
 #include <utility>
 
 #define DEBUG_TYPE "linalg-transforms"
@@ -1406,13 +1407,18 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
 
 template <typename Conv2DOp, typename Conv1DOp>
 FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
-    returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
+    returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const {
+  // Check if this LinalgOp is of the expected Conv2DOp type (named or generic).
+  SmallVector<int64_t> dilations, strides;
+  if (!linalg::isaConvolutionOpOfType<Conv2DOp>(convOp, &dilations, &strides))
+    return failure();
+
   if (convOp.hasPureBufferSemantics())
     return failure(); // To be implemented.
 
-  Value input = convOp.getInputs().front();
-  Value kernel = convOp.getInputs().back();
-  Value output = convOp.getOutputs().front();
+  Value input = convOp.getDpsInputs().front();
+  Value kernel = convOp.getDpsInputs().back();
+  Value output = convOp.getDpsInits().front();
 
   auto inputType = dyn_cast<RankedTensorType>(input.getType());
   auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
@@ -1421,38 +1427,33 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
   auto kernelShape = kernelType.getShape();
   auto outputShape = outputType.getShape();
 
-  // Get domain indices based on conv2D layout.
-  auto [khIndex, kwIndex, ohIndex, owIndex] =
-      TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
-          convOp)
-          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
-            return std::make_tuple(0, 1, 1, 2);
-          })
-          .Case([&](linalg::Conv2DNchwFchwOp op) {
-            return std::make_tuple(2, 3, 2, 3);
-          })
-          .Case([&](linalg::PoolingNhwcSumOp op) {
-            return std::make_tuple(0, 1, 1, 2);
-          })
-          .Case([&](linalg::PoolingNchwSumOp op) {
-            return std::make_tuple(0, 1, 2, 3);
-          })
-          .Case([&](linalg::PoolingNhwcMaxOp op) {
-            return std::make_tuple(0, 1, 1, 2);
-          })
-          .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
-            return std::make_tuple(0, 1, 1, 2);
-          })
-          .Case([&](linalg::PoolingNhwcMinOp op) {
-            return std::make_tuple(0, 1, 1, 2);
-          })
-          .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
-            return std::make_tuple(0, 1, 1, 2);
-          })
-          .Case([&](linalg::PoolingNchwMaxOp op) {
-            return std::make_tuple(0, 1, 2, 3);
-          })
-          .DefaultUnreachable("unexpected conv2d/pool2d operation.");
+  // Get domain indices based on Conv2DOp type. These are known at compile time.
+  int64_t khIndex, kwIndex, ohIndex, owIndex;
+  if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNhwcHwcfOp> ||
+                std::is_same_v<Conv2DOp, linalg::PoolingNhwcSumOp> ||
+                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxOp> ||
+                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxUnsignedOp> ||
+                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinOp> ||
+                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinUnsignedOp>) {
+    // NHWC layout: kernel [H, W, ...], output [N, H, W, C]
+    khIndex = 0;
+    kwIndex = 1;
+    ohIndex = 1;
+    owIndex = 2;
+  } else if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNchwFchwOp>) {
+    // NCHW_FCHW layout: kernel [..., H, W], output [N, C, H, W]
+    khIndex = 2;
+    kwIndex = 3;
+    ohIndex = 2;
+    owIndex = 3;
+  } else if constexpr (std::is_same_v<Conv2DOp, linalg::PoolingNchwSumOp> ||
+                       std::is_same_v<Conv2DOp, linalg::PoolingNchwMaxOp>) {
+    // NCHW pooling layout: kernel [H, W], output [N, C, H, W]
+    khIndex = 0;
+    kwIndex = 1;
+    ohIndex = 2;
+    owIndex = 3;
+  }
 
   // Only handle the case where at least one of the window dimensions is
   // of size 1. Other cases can rely on tiling to reduce to such cases.
@@ -1484,13 +1485,9 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
 
   // Rank-reduce strides and dilations too.
   // TODO: dropDim 1-liner helper.
-  auto strides =
-      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
   strides.erase(strides.begin() + (removeH ? 0 : 1));
   auto stridesAttr = rewriter.getI64VectorAttr(strides);
 
-  auto dilations =
-      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
   dilations.erase(dilations.begin() + (removeH ? 0 : 1));
   auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
 
@@ -1527,13 +1524,19 @@ template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
 
 FailureOr<DepthwiseConv1DNwcWcOp>
 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
-    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
+    LinalgOp convOp, PatternRewriter &rewriter) const {
+  // Check if this LinalgOp is a DepthwiseConv2DNhwcHwcOp (named or generic).
+  SmallVector<int64_t> dilations, strides;
+  if (!linalg::isaConvolutionOpOfType<DepthwiseConv2DNhwcHwcOp>(
+          convOp, &dilations, &strides))
+    return failure();
+
   if (convOp.hasPureBufferSemantics())
     return failure(); // To be implemented.
 
-  Value input = convOp.getInputs().front();
-  Value kernel = convOp.getInputs().back();
-  Value output = convOp.getOutputs().front();
+  Value input = convOp.getDpsInputs().front();
+  Value kernel = convOp.getDpsInputs().back();
+  Value output = convOp.getDpsInits().front();
 
   auto inputType = dyn_cast<RankedTensorType>(input.getType());
   auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
@@ -1572,12 +1575,9 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
 
   // Rank-reduce strides and dilations too.
   // TODO: dropDim 1-liner helper.
-  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
   strides.erase(strides.begin() + (removeH ? 0 : 1));
   auto stridesAttr = rewriter.getI64VectorAttr(strides);
 
-  auto dilations =
-      llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
   dilations.erase(dilations.begin() + (removeH ? 0 : 1));
   auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
 
@@ -1594,14 +1594,19 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
 }
 
 FailureOr<Conv1DOp>
-DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
+DownscaleConv2DOp::returningMatchAndRewrite(LinalgOp convOp,
                                             PatternRewriter &rewriter) const {
+  // Check if this LinalgOp is a Conv2DOp (named or generic).
+  SmallVector<int64_t> dilations, strides;
+  if (!linalg::isaConvolutionOpOfType<Conv2DOp>(convOp, &dilations, &strides))
+    return failure();
+
   if (convOp.hasPureBufferSemantics())
     return failure(); // To be implemented.
 
-  Value input = convOp.getInputs().front();
-  Value kernel = convOp.getInputs().back();
-  Value output = convOp.getOutputs().front();
+  Value input = convOp.getDpsInputs().front();
+  Value kernel = convOp.getDpsInputs().back();
+  Value output = convOp.getDpsInits().front();
 
   auto inputType = dyn_cast<RankedTensorType>(input.getType());
   auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bb3bccdae0e14..0f9a7a1751699 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -2070,7 +2071,7 @@ vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
     return failure();
   }
 
-  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
+  if (!isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
     LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
     return failure();
   }
@@ -2431,10 +2432,11 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   if (isElementwise(linalgOp))
     return success();
 
-  // TODO: isaConvolutionOpInterface that can also infer from generic
-  // features. But we will still need stride/dilation attributes that will be
-  // annoying to reverse-engineer...
-  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
+  // Check for convolution ops - both named ops implementing
+  // ConvolutionOpInterface and generic ops that semantically match convolution
+  // patterns.
+  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()) ||
+      isaConvolutionOpInterface(linalgOp))
     return vectorizeConvOpPrecondition(linalgOp);
 
   // TODO: the common vector shape is equal to the static loop sizes only when
@@ -2639,11 +2641,11 @@ vectorizeScalableVectorPrecondition(Operation *op,
 
   // Cond 4: Only the following ops are supported in the
   // presence of scalable vectors
-  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
-                 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
-                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
-                 isa<linalg::BatchMmt4DOp>(op) ||
-                 hasReductionIterator(linalgOp));
+  return success(
+      isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
+      isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(linalgOp) ||
+      isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+      isa<linalg::BatchMmt4DOp>(op) || hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2734,7 +2736,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
             // TODO: isaConvolutionOpInterface that can also infer from
             // generic features. Will require stride/dilation attributes
             // inference.
-            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
+            if (isa<ConvolutionOpInterface>(linalgOp.getOperation()) ||
+                isaConvolutionOpInterface(linalgOp)) {
               FailureOr<Operation *> convOr = vectorizeConvolution(
                   rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                   flatten1DDepthwiseConv);
@@ -3480,6 +3483,43 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
   bindShapeDims<0>(shapedType, vals...);
 }
 
+/// Helper to extract strides and dilations for 1D convolution/pooling ops.
+/// Returns true if the op is a recognized 1D conv/pool op and extracts the
+/// stride and dilation values. For unrecognized ops, returns false.
+static bool extract1DConvPoolStrideDilation(LinalgOp op, int &strideW,
+                                            int &dilationW) {
+  SmallVector<int64_t> dilations, strides;
+
+#define EXTRACT_1D_CONV_POOL_STRIDE_DILATION(ConvOpTy)                         \
+  if (isaConvolutionOpOfType<ConvOpTy>(op, &dilations, &strides)) {            \
+    strideW = static_cast<int>(strides.front());                               \
+    dilationW = static_cast<int>(dilations.front());                           \
+    return true;                                                               \
+  }
+
+  // 1D Convolution ops
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DNwcWcfOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DNcwFcwOp);
+  // Depthwise 1D Convolution ops
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNwcWcOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNcwCwOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNwcWcmOp);
+  // 1D Pooling ops (NWC layout)
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcSumOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMaxOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMaxUnsignedOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMinOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMinUnsignedOp);
+  // 1D Pooling ops (NCW layout)
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNcwSumOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNcwMaxOp);
+
+#undef EXTRACT_1D_CONV_POOL_STRIDE_DILATION
+
+  return false;
+}
+
 namespace {
 /// Generate a vector implementation for either:
 /// ```
@@ -3535,14 +3575,19 @@ struct Conv1DGenerator
     auto maybeKind = getCombinerOpKind(reduceOp);
     reductionKind = maybeKind.value();
 
-    // The ConvolutionOpInterface gives us guarantees of existence for
-    // strides/dilations. However, we do not need to rely on those, we can
-    // simply use them if present, otherwise use the default and let the generic
-    // conv. matcher in the ConvGenerator succeed or fail.
-    auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
-    auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
-    strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
-    dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+    // Try to extract strides/dilations from named 1D conv/pool ops using
+    // isaConvolutionOpOfType. This works for both named ops and generic ops
+    // that match their semantics. For unrecognized generic ops, fall back to
+    // checking attributes directly (which may not exist for generic ops).
+    if (!extract1DConvPoolStrideDilation(linalgOp, strideW, dilationW)) {
+      // Fallback: check for stride/dilation attributes directly.
+      // For generic ops without these attributes, default to 1.
+      auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
+      auto dilations =
+          linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+      strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
+      dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+    }
   }
 
   /// Generate a vector implementation for:
@@ -4265,13 +4310,14 @@ static FailureOr<Operation *> vectorizeConvolution(
   if (!inputVecSizes.empty()) {
     // Only use the input vector size corresponding to the channel dim. Other
     // vector dims will be inferred from the Ops.
-    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
-            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
+    assert((isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(op) ||
+            isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(op)) &&
            "Not a 1D depthwise conv!");
-    size_t chDimIdx =
-        TypeSwitch<Operation *, size_t>(op)
-            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
-            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
+    size_t chDimIdx = 0;
+    if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(op))
+      chDimIdx = 2;
+    else if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(op))
+      chDimIdx = 1;
 
     vecChDimSize = inputVecSizes[chDimIdx];
     vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Util...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/174196


More information about the Mlir-commits mailing list