[Mlir-commits] [mlir] [mlir][linalg] Fix padding shape computation in PadTilingInterface for convs (PR #149576)
Vivian Zhang
llvmlistbot at llvm.org
Mon Jul 28 20:58:30 PDT 2025
https://github.com/yzhang93 updated https://github.com/llvm/llvm-project/pull/149576
>From a0b980a7a0188a82837fc165c5bcc13399bb7439 Mon Sep 17 00:00:00 2001
From: yzhang93 <zhyuhang88 at gmail.com>
Date: Fri, 18 Jul 2025 19:24:48 +0000
Subject: [PATCH 1/4] [mlir] Fix padding shape computation in
PadTilingInterface
---
.../Linalg/Transforms/PadTilingInterface.cpp | 17 +++-
...m-op-pad-tiling-interface-multiple-of.mlir | 87 +++++++++++++++++--
.../transform-op-pad-tiling-interface.mlir | 24 ++---
3 files changed, 104 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 5eb3761f7aca1..c465383771617 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -114,24 +114,31 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
/*compressDims=*/true);
// If we are padding to the next multiple of, compose with ceil(sz) * sz.
+ OpFoldResult paddingDimOfr;
if (options.padToMultipleOf) {
AffineExpr d0, s0;
bindDims(rewriter.getContext(), d0);
bindSymbols(rewriter.getContext(), s0);
AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
AffineMap composedMap = projectedMap.compose(ceilMap);
- OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+ paddingDimOfr = affine::makeComposedFoldedAffineApply(
rewriter, loc, composedMap,
{indexingSizes[paddingDim], paddingSize},
/*composeAffineMin=*/true);
- terms.push_back(paddingDimOfr);
} else {
// Otherwise just set to paddingSize.
- OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+ paddingDimOfr = affine::makeComposedFoldedAffineApply(
rewriter, loc, projectedMap, paddingSize);
- terms.push_back(paddingDimOfr);
}
+ // Adjust for the maximum accessed index which is (padding_size - 1).
+ AffineExpr d0;
+ bindDims(rewriter.getContext(), d0);
+ AffineMap subtractOneMap = AffineMap::get(1, 0, d0 - 1);
+ OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, subtractOneMap, {paddingDimOfr});
+ terms.push_back(maxAccessIdx);
+
LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
}
@@ -148,6 +155,8 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
AffineExpr sumExpr = dims.front();
for (unsigned i = 1; i < dims.size(); ++i)
sumExpr = sumExpr + dims[i];
+ // Add 1 to the maximum accessed index and get the final padded size.
+ sumExpr = sumExpr + rewriter.getAffineConstantExpr(1);
OpFoldResult paddedDimOfr =
affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms);
paddedShape[resultIndex] = paddedDimOfr;
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
index 78619b682673e..53cb7d7767b9a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
@@ -52,22 +52,22 @@ module {
// CHECK-LABEL: @generic
// CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>)
- func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>)
+ func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.
// CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[2, 0]
// CHECK: : tensor<7x5xf32> to tensor<9x5xf32>
// CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[2, 4, 2] {
- // CHECK: : tensor<7x11x12xf32> to tensor<9x15x14xf32>
+ // CHECK: : tensor<7x11x11xf32> to tensor<9x15x13xf32>
// CHECK-NEXT: linalg.generic
- // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<9x15x14xf32> to tensor<7x11x12xf32>
- %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<9x15x13xf32> to tensor<7x11x11xf32>
+ %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
- } -> tensor<7x11x12xf32>
- return %0 : tensor<7x11x12xf32>
+ } -> tensor<7x11x11xf32>
+ return %0 : tensor<7x11x11xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -83,7 +83,7 @@ module {
// -----
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 5)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 4)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>
#map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -272,3 +272,74 @@ module attributes {transform.with_named_sequence} {
}
}
+// -----
+
+// CHECK-LABEL: pad_conv
+func.func @pad_conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+ // Padded input W = (16 - 1) * stride(1) + (3 - 1) * dilation(1) + 1 = 18.
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+ // CHECK: : tensor<1x16x16x4xf32> to tensor<1x16x18x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+ // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { // iterator 2 = output W, iterator 6 = channels
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16 + 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16)>
+
+// CHECK-LABEL: pad_conv_dynamic
+func.func @pad_conv_dynamic(%arg0: tensor<1x16x?x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> {
+ // Input W high pad = (OW ceildiv 16) * 16 + (3 - 1) - W (stride 1, dilation 1).
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[D0_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+ // CHECK: %[[D0_1:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x16x?x4xf32>
+ // CHECK: %[[H0:.*]] = affine.apply #[[$MAP0]]()[%[[D0_0]], %[[D0_1]]]
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H0]], 12]
+ // CHECK: : tensor<1x16x?x4xf32> to tensor<1x16x?x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: %[[D1_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+ // CHECK: %[[H1:.*]] = affine.apply #[[$MAP1]]()[%[[D0_0]], %[[D1_0]]]
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H1]], 0]
+ // CHECK: : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+ // CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, %[[D2_0]], 16] [1, 1, 1, 1] : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+ // Output W high pad = (OW ceildiv 16) * 16 - OW.
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x?x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32>
+ return %0 : tensor<1x14x?x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { // iterator 2 = output W, iterator 6 = channels
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index 26c03ed309c05..f7418769f79ca 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -69,22 +69,22 @@ module {
// CHECK-LABEL: @generic
// CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>)
- func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>)
+ func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.
// CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[1, 0]
// CHECK: : tensor<7x5xf32> to tensor<8x5xf32>
// CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[1, 3, 1] {
- // CHECK: : tensor<7x11x12xf32> to tensor<8x14x13xf32>
+ // CHECK: : tensor<7x11x11xf32> to tensor<8x14x12xf32>
// CHECK-NEXT: linalg.generic
- // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<8x14x13xf32> to tensor<7x11x12xf32>
- %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<8x14x12xf32> to tensor<7x11x11xf32>
+ %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
- } -> tensor<7x11x12xf32>
- return %0 : tensor<7x11x12xf32>
+ } -> tensor<7x11x11xf32>
+ return %0 : tensor<7x11x11xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -102,7 +102,7 @@ module {
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (-s0 + 8)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 13)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 12)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>
#map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -127,13 +127,13 @@ module {
// CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<?x11x?xf32>
// CHECK: %[[H2:.*]] = affine.apply #[[$MAP1]]()[%[[D2_0]]]
// CHECK: tensor.pad %{{.*}} low[0, 0, 0] high[%[[H1]], 3, %[[H2]]] {
- // CHECK: : tensor<?x11x?xf32> to tensor<8x14x13xf32>
+ // CHECK: : tensor<?x11x?xf32> to tensor<8x14x12xf32>
//
// CHECK: %[[D0_2:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x5xf32>
// CHECK: %[[D2_1:.*]] = affine.apply #[[$MAP2]]()[%[[D0_2]]]
- // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x13xf32>) {
- // CHECK: } -> tensor<8x14x13xf32>
- // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x13xf32> to tensor<?x11x?xf32>
+ // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x12xf32>) {
+ // CHECK: } -> tensor<8x14x12xf32>
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x12xf32> to tensor<?x11x?xf32>
//
%0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<?x5xf32>) outs(%arg1 : tensor<?x11x?xf32>) {
^bb0(%in: f32, %out: f32):
>From 01f2f3a93a0d13827a90eab69d07abd12f85fd4c Mon Sep 17 00:00:00 2001
From: yzhang93 <zhyuhang88 at gmail.com>
Date: Tue, 22 Jul 2025 04:48:22 +0000
Subject: [PATCH 2/4] Add cases for non-unit strides and dilations
---
.../Linalg/Transforms/PadTilingInterface.cpp | 30 ++++++++-
...m-op-pad-tiling-interface-multiple-of.mlir | 62 +++++++++++++++++++
2 files changed, 89 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index c465383771617..3f9b48c4fdbfb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -55,6 +55,28 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
return paddingSizes;
}
+/// Returns the constant `c` of an affine expression of the form `d * c` or
+/// `c * d`, where `d` is an AffineDimExpr and `c` is an AffineConstantExpr.
+/// Any other form (a bare `d`, `d floordiv c`, `d * s0`, ...) returns 1; it
+/// is not diagnosed, so it is treated exactly like a unit multiplier.
+static int64_t extractConstantMultiplier(AffineExpr expr) {
+  if (auto binOp = dyn_cast<AffineBinaryOpExpr>(expr)) {
+    if (binOp.getKind() == AffineExprKind::Mul) {
+      auto lhsD = dyn_cast<AffineDimExpr>(binOp.getLHS());
+      auto rhsC = dyn_cast<AffineConstantExpr>(binOp.getRHS());
+      if (lhsD && rhsC) { // `d * c`
+        return rhsC.getValue();
+      }
+      auto lhsC = dyn_cast<AffineConstantExpr>(binOp.getLHS());
+      auto rhsD = dyn_cast<AffineDimExpr>(binOp.getRHS());
+      if (lhsC && rhsD) { // `c * d`
+        return lhsC.getValue();
+      }
+    }
+  }
+  return 1; // Not a `dim * constant` product.
+}
+
/// Compute the padded shape of the given value `v` of `RankedTensorType` given
/// - `indexingSizes` a list of OpFoldResult.
/// - an `indexingMap` that encodes how the shape of varies with increases
@@ -131,12 +153,14 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
rewriter, loc, projectedMap, paddingSize);
}
- // Adjust for the maximum accessed index which is (padding_size - 1).
+    // Adjust for the maximum accessed index, which is
+    // (padded iteration size - 1) * multiplier.
AffineExpr d0;
bindDims(rewriter.getContext(), d0);
- AffineMap subtractOneMap = AffineMap::get(1, 0, d0 - 1);
+ int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
+ AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
- rewriter, loc, subtractOneMap, {paddingDimOfr});
+ rewriter, loc, subtractMap, {paddingDimOfr});
terms.push_back(maxAccessIdx);
LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
index 53cb7d7767b9a..981f5dc37c859 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
@@ -343,3 +343,65 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: pad_conv_strided
+func.func @pad_conv_strided(%arg0: tensor<1x42x42x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+ // Padded input W = (16 - 1) * stride(3) + (3 - 1) * dilation(1) + 1 = 48.
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 6, 12]
+ // CHECK: : tensor<1x42x42x4xf32> to tensor<1x42x48x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+ // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<3> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x42x42x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { // iterator 2 = output W, iterator 6 = channels
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: pad_conv_dilated
+func.func @pad_conv_dilated(%arg0: tensor<1x18x18x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+ // Padded input W = (16 - 1) * stride(1) + (3 - 1) * dilation(2) + 1 = 20.
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+ // CHECK: : tensor<1x18x18x4xf32> to tensor<1x18x20x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+ // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x18x18x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { // iterator 2 = output W, iterator 6 = channels
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
>From 79e5629af41d3f230a037020b5b993e6777e1f32 Mon Sep 17 00:00:00 2001
From: yzhang93 <zhyuhang88 at gmail.com>
Date: Mon, 28 Jul 2025 23:01:44 +0000
Subject: [PATCH 3/4] Add comments
---
.../mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td | 1 +
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 7 +++++++
mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp | 7 +++++++
3 files changed, 15 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index bafeca924e4c5..743285d21c55b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1191,6 +1191,7 @@ def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interfac
iteration domain induces a padding of the operands that is consistent
across the op semantics and, unlike for simple elementwise ops, may not be
trivially deducible or specifiable on operands only (e.g. convolutions).
+ Currently, only a limited set of indexing maps is supported: results that are sums of dimensions, each optionally multiplied by a constant.
The specification of `padding_sizes` follows that of `tile_sizes` during
tiling: the value "0" on a particular iterator encode "no padding". Like in
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 9e62d0dcc7890..25aa759e02dcb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -601,6 +601,13 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
/// affine.apply operations.
/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and
/// provides a gentle portability path for Linalg-like ops with affine maps.
+/// The padded shape is computed by evaluating the maximum accessed index per
+/// dimension, which may involve multiplying by constant factors derived from
+/// the affine indexing expressions. Currently, only a limited set of indexing
+/// maps is supported, such as
+/// - affine_map<(d0, d1, d2) -> (d0, d1)>
+/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
+/// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult>
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 3f9b48c4fdbfb..c9e3219b6f6e3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -85,6 +85,13 @@ static int64_t extractConstantMultiplier(AffineExpr expr) {
/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps.
/// The implementaiton below iteratively combines increases from contributing
/// dimensions using affine.apply operations.
+/// The padded shape is computed by evaluating the maximum accessed index per
+/// dimension, which may involve multiplying by constant factors derived from
+/// the affine indexing expressions. Currently, only a limited set of indexing
+/// maps is supported, such as
+/// - affine_map<(d0, d1, d2) -> (d0, d1)>
+/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
+/// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult> linalg::computePaddedShape(
>From 735085a6e19dddf510e852000c7f35908efd53b6 Mon Sep 17 00:00:00 2001
From: yzhang93 <zhyuhang88 at gmail.com>
Date: Tue, 29 Jul 2025 03:57:45 +0000
Subject: [PATCH 4/4] Address comment
---
mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index c9e3219b6f6e3..abe7f6f08d5ab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -187,9 +187,8 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
for (unsigned i = 1; i < dims.size(); ++i)
sumExpr = sumExpr + dims[i];
// Add 1 to the maximum accessed index and get the final padded size.
- sumExpr = sumExpr + rewriter.getAffineConstantExpr(1);
- OpFoldResult paddedDimOfr =
- affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms);
+ OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, sumExpr + 1, terms);
paddedShape[resultIndex] = paddedDimOfr;
}
More information about the Mlir-commits
mailing list