[Mlir-commits] [mlir] [mlir][linalg] Decompose conv2d to series of conv1d (PR #169082)
llvmlistbot at llvm.org
Fri Nov 21 10:45:36 PST 2025
github-actions[bot] wrote:
:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff origin/main HEAD --extensions cpp -- mlir/lib/Dialect/Linalg/Transforms/DecomposeConv2DToConv1D.cpp --diff_from_common_commit
``````````
:warning:
If you are using a stacked PR workflow, the reproduction instructions above might
return results for more than one PR in the stack. You can limit the results by
changing `origin/main` to the base branch or commit you want to compare against,
as shown in the example below.
:warning:
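For example, if this PR is stacked on top of a parent branch (the branch name below is a hypothetical placeholder), you could diff against that branch instead of `origin/main`:
``````````bash
# 'users/example/parent-pr' is a placeholder; substitute your actual base branch or commit.
git-clang-format --diff users/example/parent-pr HEAD --extensions cpp -- mlir/lib/Dialect/Linalg/Transforms/DecomposeConv2DToConv1D.cpp --diff_from_common_commit
``````````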
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeConv2DToConv1D.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeConv2DToConv1D.cpp
index e02755dd9..64ad41fd2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeConv2DToConv1D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeConv2DToConv1D.cpp
@@ -33,7 +33,8 @@ namespace {
/// Constraints:
/// - Height stride and dilation must be 1 (to allow contiguous reshaping).
/// - Width stride and dilation are preserved in the 1D convolution.
-struct DecomposeConv2DToConv1DPattern final : public OpRewritePattern<Conv2DNhwcHwcfOp> {
+struct DecomposeConv2DToConv1DPattern final
+ : public OpRewritePattern<Conv2DNhwcHwcfOp> {
using OpRewritePattern<Conv2DNhwcHwcfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(Conv2DNhwcHwcfOp convOp,
@@ -52,11 +53,14 @@ struct DecomposeConv2DToConv1DPattern final : public OpRewritePattern<Conv2DNhwc
auto stridesAttr = convOp.getStrides();
auto dilationsAttr = convOp.getDilations();
- SmallVector<int64_t> strides = llvm::to_vector(stridesAttr.getValues<int64_t>());
- SmallVector<int64_t> dilations = llvm::to_vector(dilationsAttr.getValues<int64_t>());
+ SmallVector<int64_t> strides =
+ llvm::to_vector(stridesAttr.getValues<int64_t>());
+ SmallVector<int64_t> dilations =
+ llvm::to_vector(dilationsAttr.getValues<int64_t>());
if (strides[0] != 1 || dilations[0] != 1) {
- return rewriter.notifyMatchFailure(convOp, "requires stride_h=1 and dilation_h=1");
+ return rewriter.notifyMatchFailure(
+ convOp, "requires stride_h=1 and dilation_h=1");
}
// 2. Get Dimensions
@@ -84,71 +88,83 @@ struct DecomposeConv2DToConv1DPattern final : public OpRewritePattern<Conv2DNhwc
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
- auto scfLoop = scf::ForOp::create(rewriter,
- loc, zero, Kh, one, ValueRange{output},
+ auto scfLoop = scf::ForOp::create(
+ rewriter, loc, zero, Kh, one, ValueRange{output},
[&](OpBuilder &b, Location loc, Value r, ValueRange args) {
Value currentAccumulator = args[0];
// --- A. Extract Filter Slice ---
// Filter shape: [Kh, Kw, Cin, Cout] -> Slice at r: [1, Kw, Cin, Cout]
// We need to rank-reduce this to [Kw, Cin, Cout] for conv_1d.
- SmallVector<OpFoldResult> filterOffsets = {r, b.getIndexAttr(0), b.getIndexAttr(0), b.getIndexAttr(0)};
- SmallVector<OpFoldResult> filterSizes = {b.getIndexAttr(1), Kw, C_in, C_out};
- SmallVector<OpFoldResult> filterStrides = {b.getIndexAttr(1), b.getIndexAttr(1), b.getIndexAttr(1), b.getIndexAttr(1)};
+ SmallVector<OpFoldResult> filterOffsets = {
+ r, b.getIndexAttr(0), b.getIndexAttr(0), b.getIndexAttr(0)};
+ SmallVector<OpFoldResult> filterSizes = {b.getIndexAttr(1), Kw, C_in,
+ C_out};
+ SmallVector<OpFoldResult> filterStrides = {
+ b.getIndexAttr(1), b.getIndexAttr(1), b.getIndexAttr(1),
+ b.getIndexAttr(1)};
// Explicitly specify the desired result type (Rank 3)
- auto filterSliceType = RankedTensorType::get(
- {ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic},
- filterType.getElementType());
+ auto filterSliceType =
+ RankedTensorType::get({ShapedType::kDynamic, ShapedType::kDynamic,
+ ShapedType::kDynamic},
+ filterType.getElementType());
- Value filterSlice = tensor::ExtractSliceOp::create(b,
- loc, filterSliceType, filter, filterOffsets, filterSizes, filterStrides);
+ Value filterSlice = tensor::ExtractSliceOp::create(
+ b, loc, filterSliceType, filter, filterOffsets, filterSizes,
+ filterStrides);
// --- B. Extract Input Slice ---
// We need a view of the input shifted by 'r' along Height.
// Input: [N, H, W, C]. Slice starts at [0, r, 0, 0].
// Size: [N, H_out, W, C].
- // (Recall H_in = H_out + Kh - 1 generally, so H_out fits starting at r).
- SmallVector<OpFoldResult> inputOffsets = {b.getIndexAttr(0), r, b.getIndexAttr(0), b.getIndexAttr(0)};
+ // (Recall H_in = H_out + Kh - 1 generally, so H_out fits starting at
+ // r).
+ SmallVector<OpFoldResult> inputOffsets = {
+ b.getIndexAttr(0), r, b.getIndexAttr(0), b.getIndexAttr(0)};
SmallVector<OpFoldResult> inputSizes = {N, H_out, W_in, C_in};
- SmallVector<OpFoldResult> inputStrides = {b.getIndexAttr(1), b.getIndexAttr(1), b.getIndexAttr(1), b.getIndexAttr(1)};
+ SmallVector<OpFoldResult> inputStrides = {
+ b.getIndexAttr(1), b.getIndexAttr(1), b.getIndexAttr(1),
+ b.getIndexAttr(1)};
- Value inputSlice = tensor::ExtractSliceOp::create(b,
- loc, input, inputOffsets, inputSizes, inputStrides);
+ Value inputSlice = tensor::ExtractSliceOp::create(
+ b, loc, input, inputOffsets, inputSizes, inputStrides);
// --- C. Reshape Input for Conv1D ---
// Conv1D expects [Batch, Width, Channels].
// We have [N, H_out, W_in, C_in].
// We collapse N and H_out into a single Batch dimension.
- SmallVector<ReassociationIndices> collapseIndicesInput = {{0, 1}, {2}, {3}};
- Value reshapedInput = tensor::CollapseShapeOp::create(b,
- loc, inputSlice, collapseIndicesInput);
+ SmallVector<ReassociationIndices> collapseIndicesInput = {
+ {0, 1}, {2}, {3}};
+ Value reshapedInput = tensor::CollapseShapeOp::create(
+ b, loc, inputSlice, collapseIndicesInput);
// --- D. Reshape Accumulator for Conv1D ---
// Current Accumulator: [N, H_out, W_out, C_out].
// Target: [N * H_out, W_out, C_out].
- Value reshapedAcc = tensor::CollapseShapeOp::create(b,
- loc, currentAccumulator, collapseIndicesInput);
+ Value reshapedAcc = tensor::CollapseShapeOp::create(
+ b, loc, currentAccumulator, collapseIndicesInput);
// --- E. Perform Conv1D ---
// Op: linalg.conv_1d_nwc_wcf
- // Strides and Dilations for W are passed through from the original Op.
- // Original Strides: [Stride_H, Stride_W]. We take Stride_W.
+ // Strides and Dilations for W are passed through from the original
+ // Op. Original Strides: [Stride_H, Stride_W]. We take Stride_W.
auto strideW = strides[1];
auto dilationW = dilations[1];
- auto conv1d = Conv1DNwcWcfOp::create(b, loc,
- TypeRange{reshapedAcc.getType()},
- ValueRange{reshapedInput, filterSlice},
- ValueRange{reshapedAcc},
+ auto conv1d = Conv1DNwcWcfOp::create(
+ b, loc, TypeRange{reshapedAcc.getType()},
+ ValueRange{reshapedInput, filterSlice}, ValueRange{reshapedAcc},
b.getDenseI64ArrayAttr({strideW}),
b.getDenseI64ArrayAttr({dilationW}));
// --- F. Expand Result back to 4D ---
// Result: [N * H_out, W_out, C_out] -> [N, H_out, W_out, C_out]
- // We use the Type of the currentAccumulator to ensure correct dynamic dim reconstruction.
- Value expandedResult = tensor::ExpandShapeOp::create(b,
- loc, currentAccumulator.getType(), conv1d.getResult(0), collapseIndicesInput);
+ // We use the Type of the currentAccumulator to ensure correct dynamic
+ // dim reconstruction.
+ Value expandedResult = tensor::ExpandShapeOp::create(
+ b, loc, currentAccumulator.getType(), conv1d.getResult(0),
+ collapseIndicesInput);
scf::YieldOp::create(b, loc, expandedResult);
});
@@ -160,7 +176,9 @@ struct DecomposeConv2DToConv1DPattern final : public OpRewritePattern<Conv2DNhwc
} // namespace
-struct LinalgDecomposeConv2DtoConv1D final : public impl::LinalgDecomposeConv2DToConv1DBase<LinalgDecomposeConv2DtoConv1D> {
+struct LinalgDecomposeConv2DtoConv1D final
+ : public impl::LinalgDecomposeConv2DToConv1DBase<
+ LinalgDecomposeConv2DtoConv1D> {
using Base::Base;
void runOnOperation() override {
``````````
</details>
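If you want to apply the suggested formatting rather than just view it, one option (not part of the bot message, and assuming the diff applies cleanly from the root of the llvm-project checkout) is to save the diff shown above to a file and apply it with `git apply`:
``````````bash
# Save the clang-format diff above to clang-format.patch, then apply it
# to the working tree from the repository root and amend/commit as usual.
git apply clang-format.patch
``````````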
https://github.com/llvm/llvm-project/pull/169082
More information about the Mlir-commits mailing list