[Mlir-commits] [mlir] [mlir][linalg] Decompose conv2d to series of conv1d (PR #169082)
Isaac Nudelman
llvmlistbot at llvm.org
Fri Nov 21 10:43:51 PST 2025
https://github.com/nuudlman created https://github.com/llvm/llvm-project/pull/169082
WIP...
>From 8116510b82639a08fcb8ca94cc299f8fed31b28a Mon Sep 17 00:00:00 2001
From: Isaac Nudelman <isaac.nudelman at utexas.edu>
Date: Fri, 21 Nov 2025 12:37:56 -0600
Subject: [PATCH] Pass to decompose linalg.conv2d_nhwc_hwcf to series of
linalg.conv1d_nwc_wcf's
---
mlir/include/mlir/Dialect/Linalg/Passes.td | 9 +
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 +
.../Transforms/DecomposeConv2DToConv1D.cpp | 173 ++++++++++++++++++
3 files changed, 183 insertions(+)
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/DecomposeConv2DToConv1D.cpp
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 44da2965e6892..20a7d13f450fe 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -267,4 +267,13 @@ def LinalgBlockPackMatmul : Pass<"linalg-block-pack-matmul"> {
];
}
+def LinalgDecomposeConv2DToConv1D : Pass<"linalg-decompose-conv2d-to-conv1d"> {
+  let summary = "Decompose a conv2d into a series of conv1d ops";
+  let description = [{
+    Rewrites each `linalg.conv_2d_nhwc_hwcf` op into an `scf.for` loop over
+    the filter height that accumulates `linalg.conv_1d_nwc_wcf` results
+    (row-wise, a.k.a. shift-and-add, decomposition). Only convolutions with
+    unit stride and unit dilation along the height dimension are rewritten;
+    width stride and dilation are preserved in the 1-D convolutions.
+  }];
+
+  // The rewrite materializes arith.constant index ops in addition to the
+  // linalg/tensor/scf ops, so arith must be registered as well.
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "linalg::LinalgDialect",
+    "scf::SCFDialect",
+    "tensor::TensorDialect"
+  ];
+}
+
#endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index fb39e18691e03..da831955e78ce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
ConvertConv2DToImg2Col.cpp
DataLayoutPropagation.cpp
DecomposeLinalgOps.cpp
+ DecomposeConv2DToConv1D.cpp
Detensorize.cpp
DropUnitDims.cpp
ElementwiseOpFusion.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeConv2DToConv1D.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeConv2DToConv1D.cpp
new file mode 100644
index 0000000000000..e02755dd9838b
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeConv2DToConv1D.cpp
@@ -0,0 +1,173 @@
+//===- DecomposeConv2DToConv1D.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// Converts a conv2d into a series of conv1d ops using row-wise decomposition
+/// (also known as shift-and-add)
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGDECOMPOSECONV2DTOCONV1D
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// Decomposes a linalg.conv_2d_nhwc_hwcf op into a sequence of
+/// linalg.conv_1d_nwc_wcf ops using a shift-and-add approach.
+///
+/// Constraints:
+/// - Height stride and dilation must be 1 (to allow contiguous reshaping).
+/// - Width stride and dilation are preserved in the 1D convolution.
+/// Decomposes a linalg.conv_2d_nhwc_hwcf op into a sequence of
+/// linalg.conv_1d_nwc_wcf ops using a shift-and-add approach: for each row r
+/// of the filter, a 1-D convolution of the height-shifted input with that
+/// filter row is accumulated into the output.
+///
+/// Constraints:
+/// - Tensor semantics only (the rewrite materializes tensor.extract_slice /
+///   tensor.collapse_shape / tensor.expand_shape ops).
+/// - Height stride and dilation must be 1 (to allow contiguous reshaping).
+/// - Width stride and dilation are preserved in the 1D convolution.
+struct DecomposeConv2DToConv1DPattern final
+    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
+  using OpRewritePattern<Conv2DNhwcHwcfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(Conv2DNhwcHwcfOp convOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = convOp.getLoc();
+    Value input = convOp.getInputs()[0];
+    Value filter = convOp.getInputs()[1];
+    Value output = convOp.getOutputs()[0];
+
+    // Guard instead of cast<>: memref-based (buffer-semantics) convolutions
+    // must not be matched, and cast<RankedTensorType> would assert on them.
+    auto inputType = dyn_cast<RankedTensorType>(input.getType());
+    auto filterType = dyn_cast<RankedTensorType>(filter.getType());
+    if (!inputType || !filterType)
+      return rewriter.notifyMatchFailure(convOp,
+                                         "expected ranked tensor operands");
+
+    // 1. Validate strides and dilations. Only Stride_H = 1 and Dilation_H = 1
+    // are supported by this reshape-based decomposition.
+    SmallVector<int64_t> strides =
+        llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+    SmallVector<int64_t> dilations =
+        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    if (strides[0] != 1 || dilations[0] != 1)
+      return rewriter.notifyMatchFailure(convOp,
+                                         "requires stride_h=1 and dilation_h=1");
+
+    // 2. Get dimensions (static or dynamic, uniformly as Values).
+    //   Input:  [N, H, W, C_in]
+    //   Filter: [Kh, Kw, C_in, C_out]
+    //   Output: [N, H_out, W_out, C_out]
+    auto getDim = [&](Value v, int64_t idx) -> Value {
+      return tensor::DimOp::create(rewriter, loc, v, idx);
+    };
+
+    Value N = getDim(input, 0);
+    Value H_out = getDim(output, 1);
+    Value W_in = getDim(input, 2);
+    Value C_in = getDim(input, 3);
+
+    Value Kh = getDim(filter, 0);
+    Value Kw = getDim(filter, 1);
+    Value C_out = getDim(filter, 3);
+
+    // 3. Iterate r over the kernel height [0, Kh) with step 1, accumulating
+    // partial results into 'output' via the loop-carried value.
+    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
+
+    auto scfLoop = scf::ForOp::create(
+        rewriter, loc, zero, Kh, one, ValueRange{output},
+        [&](OpBuilder &b, Location loc, Value r, ValueRange args) {
+          Value currentAccumulator = args[0];
+          OpFoldResult zeroAttr = b.getIndexAttr(0);
+          OpFoldResult oneAttr = b.getIndexAttr(1);
+
+          // --- A. Extract filter slice ---
+          // Filter [Kh, Kw, Cin, Cout] -> slice at r: [1, Kw, Cin, Cout],
+          // rank-reduced to [Kw, Cin, Cout] as conv_1d expects.
+          SmallVector<OpFoldResult> filterOffsets = {r, zeroAttr, zeroAttr,
+                                                     zeroAttr};
+          SmallVector<OpFoldResult> filterSizes = {oneAttr, Kw, C_in, C_out};
+          SmallVector<OpFoldResult> filterStrides(4, oneAttr);
+
+          // Explicitly request the rank-reduced (rank-3) result type.
+          auto filterSliceType = RankedTensorType::get(
+              {ShapedType::kDynamic, ShapedType::kDynamic,
+               ShapedType::kDynamic},
+              filterType.getElementType());
+          Value filterSlice = tensor::ExtractSliceOp::create(
+              b, loc, filterSliceType, filter, filterOffsets, filterSizes,
+              filterStrides);
+
+          // --- B. Extract input slice ---
+          // View of the input shifted by 'r' along height: offsets
+          // [0, r, 0, 0], sizes [N, H_out, W_in, C_in]. (With stride_h = 1,
+          // H_in = H_out + Kh - 1, so H_out rows fit starting at row r.)
+          SmallVector<OpFoldResult> inputOffsets = {zeroAttr, r, zeroAttr,
+                                                    zeroAttr};
+          SmallVector<OpFoldResult> inputSizes = {N, H_out, W_in, C_in};
+          SmallVector<OpFoldResult> inputStrides(4, oneAttr);
+          Value inputSlice = tensor::ExtractSliceOp::create(
+              b, loc, input, inputOffsets, inputSizes, inputStrides);
+
+          // --- C/D. Collapse to conv_1d's [Batch, Width, Channels] form ---
+          // N and H_out fuse into a single batch dimension for both the input
+          // slice and the accumulator: [N, H, W, C] -> [N * H, W, C].
+          SmallVector<ReassociationIndices> collapseIndices = {{0, 1}, {2},
+                                                               {3}};
+          Value reshapedInput = tensor::CollapseShapeOp::create(
+              b, loc, inputSlice, collapseIndices);
+          Value reshapedAcc = tensor::CollapseShapeOp::create(
+              b, loc, currentAccumulator, collapseIndices);
+
+          // --- E. Perform the 1-D convolution over the width dimension ---
+          // Width stride/dilation carry over from the original op. Note the
+          // named conv ops take DenseIntElementsAttr (i64 tensor) for
+          // strides/dilations — the same attribute kind getStrides() /
+          // getDilations() returned above — not DenseI64ArrayAttr.
+          auto conv1d = Conv1DNwcWcfOp::create(
+              b, loc, TypeRange{reshapedAcc.getType()},
+              ValueRange{reshapedInput, filterSlice}, ValueRange{reshapedAcc},
+              b.getI64TensorAttr({strides[1]}),
+              b.getI64TensorAttr({dilations[1]}));
+
+          // --- F. Expand result back to 4-D ---
+          // [N * H_out, W_out, C_out] -> [N, H_out, W_out, C_out]; reuse the
+          // accumulator's type for correct dynamic-dim reconstruction.
+          Value expandedResult = tensor::ExpandShapeOp::create(
+              b, loc, currentAccumulator.getType(), conv1d.getResult(0),
+              collapseIndices);
+
+          scf::YieldOp::create(b, loc, expandedResult);
+        });
+
+    rewriter.replaceOp(convOp, scfLoop.getResult(0));
+    return success();
+  }
+};
+
+} // namespace
+
+namespace {
+/// Pass wrapper: greedily applies the conv2d -> conv1d decomposition pattern
+/// to every matching op under the pass's root operation. File-local, so it
+/// lives in an anonymous namespace per LLVM convention.
+struct LinalgDecomposeConv2DToConv1D final
+    : public impl::LinalgDecomposeConv2DToConv1DBase<
+          LinalgDecomposeConv2DToConv1D> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    patterns.add<DecomposeConv2DToConv1DPattern>(&getContext());
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
\ No newline at end of file
More information about the Mlir-commits
mailing list