[Mlir-commits] [mlir] [Linalg] Update transform + vectorization patterns to work with generic convolution ops as well (PR #174196)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 2 03:10:45 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Abhishek Varma (Abhishek-Varma)
<details>
<summary>Changes</summary>
-- This commit updates the Linalg transform and vectorization patterns
to work with both named and generic convolution ops.
-- This required the following updates to the `isaConvolutionOpOfType` API:
1. Allow dilations/strides to be optional arguments.
2. Populate dilations/strides info for named convolution ops as well.
3. Since a generic LinalgOp is now used as the root op in the patterns above,
the `assert` that the op implements ConvolutionOpInterface has been
replaced with an early-exit `if`.
Signed-off-by: Abhishek Varma <abhvarma@<!-- -->amd.com>
---
Patch is 114.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/174196.diff
9 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+13-14)
- (modified) mlir/include/mlir/Dialect/Linalg/Utils/Utils.h (+5-3)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+56-51)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+71-25)
- (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+537-149)
- (modified) mlir/test/Dialect/Linalg/transform-op-decompose.mlir (+159-79)
- (modified) mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir (+222-1)
- (modified) mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir (+198)
- (modified) mlir/test/Dialect/Linalg/vectorization/convolution.mlir (+147)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6678d693719bf..32067358438d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1641,16 +1641,16 @@ FailureOr<linalg::GenericOp> deduplicateOperandsAndRemoveDeadResults(
//===----------------------------------------------------------------------===//
/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
-/// convolution ops.
+/// convolution ops. Works with both named ops and equivalent generic ops.
template <typename Conv2DOp, typename Conv1DOp>
struct DownscaleSizeOneWindowed2DConvolution final
- : public OpRewritePattern<Conv2DOp> {
- using OpRewritePattern<Conv2DOp>::OpRewritePattern;
+ : public OpInterfaceRewritePattern<LinalgOp> {
+ using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
- FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
+ FailureOr<Conv1DOp> returningMatchAndRewrite(LinalgOp convOp,
PatternRewriter &rewriter) const;
- LogicalResult matchAndRewrite(Conv2DOp convOp,
+ LogicalResult matchAndRewrite(LinalgOp convOp,
PatternRewriter &rewriter) const override {
return returningMatchAndRewrite(convOp, rewriter);
}
@@ -1664,29 +1664,28 @@ extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
- : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
+ : public OpInterfaceRewritePattern<LinalgOp> {
DownscaleDepthwiseConv2DNhwcHwcOp(MLIRContext *context,
PatternBenefit benefit = 1)
- : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit) {}
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}
FailureOr<DepthwiseConv1DNwcWcOp>
- returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
- PatternRewriter &rewriter) const;
+ returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const;
- LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
+ LogicalResult matchAndRewrite(LinalgOp convOp,
PatternRewriter &rewriter) const override {
return returningMatchAndRewrite(convOp, rewriter);
}
};
-struct DownscaleConv2DOp final : public OpRewritePattern<Conv2DOp> {
+struct DownscaleConv2DOp final : public OpInterfaceRewritePattern<LinalgOp> {
DownscaleConv2DOp(MLIRContext *context, PatternBenefit benefit = 1)
- : OpRewritePattern<Conv2DOp>(context, benefit) {}
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}
- FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
+ FailureOr<Conv1DOp> returningMatchAndRewrite(LinalgOp convOp,
PatternRewriter &rewriter) const;
- LogicalResult matchAndRewrite(Conv2DOp convOp,
+ LogicalResult matchAndRewrite(LinalgOp convOp,
PatternRewriter &rewriter) const override {
return returningMatchAndRewrite(convOp, rewriter);
}
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 9da01f30b52d2..16d557a6ed7fa 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -108,10 +108,12 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
/// Given a linalg `op` this function returns true if it is a convolution op of
/// type `ConvOpTy` and populates `dilations` and `strides` with values inferred
-/// from the indexing maps.
+/// from the indexing maps. If `dilations` or `strides` is nullptr, the
+/// corresponding values are not populated.
template <typename ConvOpTy>
-bool isaConvolutionOpOfType(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides);
+bool isaConvolutionOpOfType(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
//===----------------------------------------------------------------------===//
// Fusion / Tiling utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 96cc378f6c21a..7972408318b95 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -32,6 +32,7 @@
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"
+#include <type_traits>
#include <utility>
#define DEBUG_TYPE "linalg-transforms"
@@ -1406,13 +1407,18 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
template <typename Conv2DOp, typename Conv1DOp>
FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
- returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
+ returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const {
+ // Check if this LinalgOp is of the expected Conv2DOp type (named or generic).
+ SmallVector<int64_t> dilations, strides;
+ if (!linalg::isaConvolutionOpOfType<Conv2DOp>(convOp, &dilations, &strides))
+ return failure();
+
if (convOp.hasPureBufferSemantics())
return failure(); // To be implemented.
- Value input = convOp.getInputs().front();
- Value kernel = convOp.getInputs().back();
- Value output = convOp.getOutputs().front();
+ Value input = convOp.getDpsInputs().front();
+ Value kernel = convOp.getDpsInputs().back();
+ Value output = convOp.getDpsInits().front();
auto inputType = dyn_cast<RankedTensorType>(input.getType());
auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
@@ -1421,38 +1427,33 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
auto kernelShape = kernelType.getShape();
auto outputShape = outputType.getShape();
- // Get domain indices based on conv2D layout.
- auto [khIndex, kwIndex, ohIndex, owIndex] =
- TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
- convOp)
- .Case([&](linalg::Conv2DNhwcHwcfOp op) {
- return std::make_tuple(0, 1, 1, 2);
- })
- .Case([&](linalg::Conv2DNchwFchwOp op) {
- return std::make_tuple(2, 3, 2, 3);
- })
- .Case([&](linalg::PoolingNhwcSumOp op) {
- return std::make_tuple(0, 1, 1, 2);
- })
- .Case([&](linalg::PoolingNchwSumOp op) {
- return std::make_tuple(0, 1, 2, 3);
- })
- .Case([&](linalg::PoolingNhwcMaxOp op) {
- return std::make_tuple(0, 1, 1, 2);
- })
- .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
- return std::make_tuple(0, 1, 1, 2);
- })
- .Case([&](linalg::PoolingNhwcMinOp op) {
- return std::make_tuple(0, 1, 1, 2);
- })
- .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
- return std::make_tuple(0, 1, 1, 2);
- })
- .Case([&](linalg::PoolingNchwMaxOp op) {
- return std::make_tuple(0, 1, 2, 3);
- })
- .DefaultUnreachable("unexpected conv2d/pool2d operation.");
+ // Get domain indices based on Conv2DOp type. These are known at compile time.
+ int64_t khIndex, kwIndex, ohIndex, owIndex;
+ if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNhwcHwcfOp> ||
+ std::is_same_v<Conv2DOp, linalg::PoolingNhwcSumOp> ||
+ std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxOp> ||
+ std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxUnsignedOp> ||
+ std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinOp> ||
+ std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinUnsignedOp>) {
+ // NHWC layout: kernel [H, W, ...], output [N, H, W, C]
+ khIndex = 0;
+ kwIndex = 1;
+ ohIndex = 1;
+ owIndex = 2;
+ } else if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNchwFchwOp>) {
+ // NCHW_FCHW layout: kernel [..., H, W], output [N, C, H, W]
+ khIndex = 2;
+ kwIndex = 3;
+ ohIndex = 2;
+ owIndex = 3;
+ } else if constexpr (std::is_same_v<Conv2DOp, linalg::PoolingNchwSumOp> ||
+ std::is_same_v<Conv2DOp, linalg::PoolingNchwMaxOp>) {
+ // NCHW pooling layout: kernel [H, W], output [N, C, H, W]
+ khIndex = 0;
+ kwIndex = 1;
+ ohIndex = 2;
+ owIndex = 3;
+ }
// Only handle the case where at least one of the window dimensions is
// of size 1. Other cases can rely on tiling to reduce to such cases.
@@ -1484,13 +1485,9 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
// Rank-reduce strides and dilations too.
// TODO: dropDim 1-liner helper.
- auto strides =
- llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
strides.erase(strides.begin() + (removeH ? 0 : 1));
auto stridesAttr = rewriter.getI64VectorAttr(strides);
- auto dilations =
- llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
dilations.erase(dilations.begin() + (removeH ? 0 : 1));
auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
@@ -1527,13 +1524,19 @@ template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
- DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
+ LinalgOp convOp, PatternRewriter &rewriter) const {
+ // Check if this LinalgOp is a DepthwiseConv2DNhwcHwcOp (named or generic).
+ SmallVector<int64_t> dilations, strides;
+ if (!linalg::isaConvolutionOpOfType<DepthwiseConv2DNhwcHwcOp>(
+ convOp, &dilations, &strides))
+ return failure();
+
if (convOp.hasPureBufferSemantics())
return failure(); // To be implemented.
- Value input = convOp.getInputs().front();
- Value kernel = convOp.getInputs().back();
- Value output = convOp.getOutputs().front();
+ Value input = convOp.getDpsInputs().front();
+ Value kernel = convOp.getDpsInputs().back();
+ Value output = convOp.getDpsInits().front();
auto inputType = dyn_cast<RankedTensorType>(input.getType());
auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
@@ -1572,12 +1575,9 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
// Rank-reduce strides and dilations too.
// TODO: dropDim 1-liner helper.
- auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
strides.erase(strides.begin() + (removeH ? 0 : 1));
auto stridesAttr = rewriter.getI64VectorAttr(strides);
- auto dilations =
- llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
dilations.erase(dilations.begin() + (removeH ? 0 : 1));
auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
@@ -1594,14 +1594,19 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
}
FailureOr<Conv1DOp>
-DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
+DownscaleConv2DOp::returningMatchAndRewrite(LinalgOp convOp,
PatternRewriter &rewriter) const {
+ // Check if this LinalgOp is a Conv2DOp (named or generic).
+ SmallVector<int64_t> dilations, strides;
+ if (!linalg::isaConvolutionOpOfType<Conv2DOp>(convOp, &dilations, &strides))
+ return failure();
+
if (convOp.hasPureBufferSemantics())
return failure(); // To be implemented.
- Value input = convOp.getInputs().front();
- Value kernel = convOp.getInputs().back();
- Value output = convOp.getOutputs().front();
+ Value input = convOp.getDpsInputs().front();
+ Value kernel = convOp.getDpsInputs().back();
+ Value output = convOp.getDpsInits().front();
auto inputType = dyn_cast<RankedTensorType>(input.getType());
auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bb3bccdae0e14..0f9a7a1751699 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -2070,7 +2071,7 @@ vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
return failure();
}
- if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
+ if (!isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
return failure();
}
@@ -2431,10 +2432,11 @@ static LogicalResult vectorizeLinalgOpPrecondition(
if (isElementwise(linalgOp))
return success();
- // TODO: isaConvolutionOpInterface that can also infer from generic
- // features. But we will still need stride/dilation attributes that will be
- // annoying to reverse-engineer...
- if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
+ // Check for convolution ops - both named ops implementing
+ // ConvolutionOpInterface and generic ops that semantically match convolution
+ // patterns.
+ if (isa<ConvolutionOpInterface>(linalgOp.getOperation()) ||
+ isaConvolutionOpInterface(linalgOp))
return vectorizeConvOpPrecondition(linalgOp);
// TODO: the common vector shape is equal to the static loop sizes only when
@@ -2639,11 +2641,11 @@ vectorizeScalableVectorPrecondition(Operation *op,
// Cond 4: Only the following ops are supported in the
// presence of scalable vectors
- return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
- isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
- isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
- isa<linalg::BatchMmt4DOp>(op) ||
- hasReductionIterator(linalgOp));
+ return success(
+ isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
+ isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(linalgOp) ||
+ isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+ isa<linalg::BatchMmt4DOp>(op) || hasReductionIterator(linalgOp));
}
LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2734,7 +2736,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
// TODO: isaConvolutionOpInterface that can also infer from
// generic features. Will require stride/dilation attributes
// inference.
- if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
+ if (isa<ConvolutionOpInterface>(linalgOp.getOperation()) ||
+ isaConvolutionOpInterface(linalgOp)) {
FailureOr<Operation *> convOr = vectorizeConvolution(
rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
flatten1DDepthwiseConv);
@@ -3480,6 +3483,43 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
bindShapeDims<0>(shapedType, vals...);
}
+/// Helper to extract strides and dilations for 1D convolution/pooling ops.
+/// Returns true if the op is a recognized 1D conv/pool op and extracts the
+/// stride and dilation values. For unrecognized ops, returns false.
+static bool extract1DConvPoolStrideDilation(LinalgOp op, int &strideW,
+ int &dilationW) {
+ SmallVector<int64_t> dilations, strides;
+
+#define EXTRACT_1D_CONV_POOL_STRIDE_DILATION(ConvOpTy) \
+ if (isaConvolutionOpOfType<ConvOpTy>(op, &dilations, &strides)) { \
+ strideW = static_cast<int>(strides.front()); \
+ dilationW = static_cast<int>(dilations.front()); \
+ return true; \
+ }
+
+ // 1D Convolution ops
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DNwcWcfOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DNcwFcwOp);
+ // Depthwise 1D Convolution ops
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNwcWcOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNcwCwOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNwcWcmOp);
+ // 1D Pooling ops (NWC layout)
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcSumOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMaxOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMaxUnsignedOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMinOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMinUnsignedOp);
+ // 1D Pooling ops (NCW layout)
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNcwSumOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNcwMaxOp);
+
+#undef EXTRACT_1D_CONV_POOL_STRIDE_DILATION
+
+ return false;
+}
+
namespace {
/// Generate a vector implementation for either:
/// ```
@@ -3535,14 +3575,19 @@ struct Conv1DGenerator
auto maybeKind = getCombinerOpKind(reduceOp);
reductionKind = maybeKind.value();
- // The ConvolutionOpInterface gives us guarantees of existence for
- // strides/dilations. However, we do not need to rely on those, we can
- // simply use them if present, otherwise use the default and let the generic
- // conv. matcher in the ConvGenerator succeed or fail.
- auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
- auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
- strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
- dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+ // Try to extract strides/dilations from named 1D conv/pool ops using
+ // isaConvolutionOpOfType. This works for both named ops and generic ops
+ // that match their semantics. For unrecognized generic ops, fall back to
+ // checking attributes directly (which may not exist for generic ops).
+ if (!extract1DConvPoolStrideDilation(linalgOp, strideW, dilationW)) {
+ // Fallback: check for stride/dilation attributes directly.
+ // For generic ops without these attributes, default to 1.
+ auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
+ auto dilations =
+ linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+ strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
+ dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+ }
}
/// Generate a vector implementation for:
@@ -4265,13 +4310,14 @@ static FailureOr<Operation *> vectorizeConvolution(
if (!inputVecSizes.empty()) {
// Only use the input vector size corresponding to the channel dim. Other
// vector dims will be inferred from the Ops.
- assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
- isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
+ assert((isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(op) ||
+ isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(op)) &&
"Not a 1D depthwise conv!");
- size_t chDimIdx =
- TypeSwitch<Operation *, size_t>(op)
- .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
- .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
+ size_t chDimIdx = 0;
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(op))
+ chDimIdx = 2;
+ else if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(op))
+ chDimIdx = 1;
vecChDimSize = inputVecSizes[chDimIdx];
vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Util...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/174196
More information about the Mlir-commits
mailing list