[Mlir-commits] [mlir] dc6d7f0 - [mlir][linalg] Fix padding shape computation in PadTilingInterface for convs (#149576)

llvmlistbot at llvm.org
Tue Jul 29 09:58:34 PDT 2025


Author: Vivian Zhang
Date: 2025-07-29T09:58:30-07:00
New Revision: dc6d7f0637e7c80e39e8b7f0e8b61515b4961b0f

URL: https://github.com/llvm/llvm-project/commit/dc6d7f0637e7c80e39e8b7f0e8b61515b4961b0f
DIFF: https://github.com/llvm/llvm-project/commit/dc6d7f0637e7c80e39e8b7f0e8b61515b4961b0f.diff

LOG: [mlir][linalg] Fix padding shape computation in PadTilingInterface for convs (#149576)

This PR fixes the computation of padded shapes for convolution-style
affine maps (e.g., d0 + d1) in `PadTilingInterface`. Previously, the
code used the direct sum of loop upper bounds, leading to over-padding.
For example, for the following `conv_2d_nhwc_fhwc` op, padding only the c
dimension to a multiple of 16 also incorrectly pads the convolved
dimensions and generates the wrong input shape:

```
%padded = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 1, 1, 12] {
^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
  tensor.yield %cst : f32
} : tensor<1x16x16x4xf32> to tensor<1x17x17x16xf32>
%padded_0 = tensor.pad %arg1 low[0, 0, 0, 0] high[0, 0, 0, 12] {
^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
  tensor.yield %cst : f32
} : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%padded, %padded_0 : tensor<1x17x17x16xf32>, tensor<16x3x3x16xf32>) outs(%arg2 : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
return %0 : tensor<1x14x14x16xf32>
```

The new implementation uses the maximum accessed index as the input to the
affine map and then adds 1 after aggregating all the terms to get the
final padded size. This fixes
https://github.com/llvm/llvm-project/issues/148679.
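
To make the arithmetic concrete: the input width of the op above is accessed
through d2 + d5 with bounds ow = 14 and kw = 3. The old direct sum gives
14 + 3 = 17, one more than the actual input width of 16 (hence the spurious
high padding of 1), while the maximum accessed index is
(14 - 1) + (3 - 1) = 15, so the correct size is 15 + 1 = 16. Below is a
minimal standalone sketch of the fixed rule; it is illustrative C++ only,
not the MLIR implementation, and `paddedSize` and its parameters are
hypothetical names:

```
// Sketch of the fixed padded-size rule for an access expression
// sum_i (d_i * c_i): each term contributes its maximum accessed index
// (bound_i - 1) * c_i, and 1 is added once after summing all terms.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// bounds[i]: (possibly already padded) loop upper bound of d_i.
// multipliers[i]: constant factor c_i (1 when the dim appears bare).
static int64_t paddedSize(const std::vector<int64_t> &bounds,
                          const std::vector<int64_t> &multipliers) {
  assert(bounds.size() == multipliers.size());
  int64_t maxIndex = 0;
  for (std::size_t i = 0; i < bounds.size(); ++i)
    maxIndex += (bounds[i] - 1) * multipliers[i];
  return maxIndex + 1; // largest accessed index, plus one, is the size.
}

int main() {
  // Stride-1 conv width, d2 + d5 with ow = 14, kw = 3: 13 + 2 + 1 = 16.
  assert(paddedSize({14, 3}, {1, 1}) == 16);
  // Strided conv width, d2 * 3 + d5 with ow = 14, kw = 3: 39 + 2 + 1 = 42.
  assert(paddedSize({14, 3}, {3, 1}) == 42);
  return 0;
}
```

With ow padded to 16 in the strided test below, the same rule gives
(16 - 1) * 3 + 2 + 1 = 48, matching the tensor<1x42x48x16xf32> shape
checked there.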

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
    mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
    mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 8d45c40a93e2b..61ce23f07faa8 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1191,6 +1191,7 @@ def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interfac
     iteration domain induces a padding of the operands that is consistent
     across the op semantics and, unlike for simple elementwise ops, may not be
     trivially deducible or specifiable on operands only (e.g. convolutions).
+    Currently, only a limited set of projected permutation maps are supported.
     
     The specification of `padding_sizes` follows that of `tile_sizes` during
     tiling: the value "0" on a particular iterator encodes "no padding". Like in

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index e625eefac5f78..d4ffe0a91fcfe 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -611,6 +611,13 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
 /// affine.apply operations.
 /// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and
 /// provides a gentle portability path for Linalg-like ops with affine maps.
+/// The padded shape is computed by evaluating the maximum accessed index per
+/// dimension, which may involve multiplying by constant factors derived from
+/// the affine indexing expressions. Currently, only a limited set of projected
+/// permutation indexing maps are supported, such as
+/// - affine_map<(d0, d1, d2) -> (d0, d1)>
+/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
+/// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
 /// In the future, more general interfaces can be devised to encode similar
 /// shape evolutions and map between an op and its operands.
 SmallVector<OpFoldResult>

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 2c62cb628a7dd..2e6252336dfeb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -55,6 +55,28 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
   return paddingSizes;
 }
 
+/// Extracts the constant multiplier from an affine expression of the form
+/// `d * c` or `c * d`, where `d` is an AffineDimExpr and `c` is an
+/// AffineConstantExpr. Returns 1 if the expression is not a simple
+/// multiplication of a dimension and a constant.
+static int64_t extractConstantMultiplier(AffineExpr expr) {
+  if (auto binOp = dyn_cast<AffineBinaryOpExpr>(expr)) {
+    if (binOp.getKind() == AffineExprKind::Mul) {
+      auto lhsD = dyn_cast<AffineDimExpr>(binOp.getLHS());
+      auto rhsC = dyn_cast<AffineConstantExpr>(binOp.getRHS());
+      if (lhsD && rhsC) {
+        return rhsC.getValue();
+      }
+      auto lhsC = dyn_cast<AffineConstantExpr>(binOp.getLHS());
+      auto rhsD = dyn_cast<AffineDimExpr>(binOp.getRHS());
+      if (lhsC && rhsD) {
+        return lhsC.getValue();
+      }
+    }
+  }
+  return 1;
+}
+
 /// Compute the padded shape of the given value `v` of `RankedTensorType` given
 ///   - `indexingSizes` a list of OpFoldResult.
 ///   - an `indexingMap` that encodes how the shape of `v` varies with increases
@@ -63,6 +85,13 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
 /// The `indexingMap` + `indexingSizes` encoding suits StructuredOps.
 /// The implementation below iteratively combines increases from contributing
 /// dimensions using affine.apply operations.
+/// The padded shape is computed by evaluating the maximum accessed index per
+/// dimension, which may involve multiplying by constant factors derived from
+/// the affine indexing expressions. Currently, only a limited set of projected
+/// permutation indexing maps are supported, such as
+/// - affine_map<(d0, d1, d2) -> (d0, d1)>
+/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
+/// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
 /// In the future, more general interfaces can be devised to encode similar
 /// shape evolutions and map between an op and its operands.
 SmallVector<OpFoldResult> linalg::computePaddedShape(
@@ -114,24 +143,33 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
                             /*compressDims=*/true);
 
       // If we are padding to the next multiple of, compose with ceil(sz) * sz.
+      OpFoldResult paddingDimOfr;
       if (options.padToMultipleOf) {
         AffineExpr d0, s0;
         bindDims(rewriter.getContext(), d0);
         bindSymbols(rewriter.getContext(), s0);
         AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
         AffineMap composedMap = projectedMap.compose(ceilMap);
-        OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+        paddingDimOfr = affine::makeComposedFoldedAffineApply(
             rewriter, loc, composedMap,
             {indexingSizes[paddingDim], paddingSize},
             /*composeAffineMin=*/true);
-        terms.push_back(paddingDimOfr);
       } else {
         // Otherwise just set to paddingSize.
-        OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+        paddingDimOfr = affine::makeComposedFoldedAffineApply(
             rewriter, loc, projectedMap, paddingSize);
-        terms.push_back(paddingDimOfr);
       }
 
+      // Adjust for the maximum accessed index, which is (paddingSize - 1) *
+      // multiplier.
+      AffineExpr d0;
+      bindDims(rewriter.getContext(), d0);
+      int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
+      AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
+      OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, subtractMap, {paddingDimOfr});
+      terms.push_back(maxAccessIdx);
+
       LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
     }
 
@@ -148,8 +186,9 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
     AffineExpr sumExpr = dims.front();
     for (unsigned i = 1; i < dims.size(); ++i)
       sumExpr = sumExpr + dims[i];
-    OpFoldResult paddedDimOfr =
-        affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms);
+    // Add 1 to the maximum accessed index and get the final padded size.
+    OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, sumExpr + 1, terms);
     paddedShape[resultIndex] = paddedDimOfr;
   }
 

diff  --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
index 78619b682673e..981f5dc37c859 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
@@ -52,22 +52,22 @@ module {
 
 // CHECK-LABEL: @generic
 // CHECK-SAME:      %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME:      %[[T1:.*]]: tensor<7x11x12xf32>)
-  func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME:      %[[T1:.*]]: tensor<7x11x11xf32>)
+  func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {
 
   //  CHECK-DAG: %[[CST:.*]] = arith.constant 0.
 
   //      CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[2, 0]
   //      CHECK:   : tensor<7x5xf32> to tensor<9x5xf32>
   //      CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[2, 4, 2] {
-  //      CHECK:   : tensor<7x11x12xf32> to tensor<9x15x14xf32>
+  //      CHECK:   : tensor<7x11x11xf32> to tensor<9x15x13xf32>
   // CHECK-NEXT: linalg.generic
-  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<9x15x14xf32> to tensor<7x11x12xf32>
-  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<9x15x13xf32> to tensor<7x11x11xf32>
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
     ^bb0(%in: f32, %out: f32):
       linalg.yield %in : f32
-    } -> tensor<7x11x12xf32>
-    return %0 : tensor<7x11x12xf32>
+    } -> tensor<7x11x11xf32>
+    return %0 : tensor<7x11x11xf32>
   }
   module attributes {transform.with_named_sequence} {
     transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -83,7 +83,7 @@ module {
 // -----
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 5)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 4)>
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>
 
 #map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -272,3 +272,136 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+//     CHECK-LABEL: pad_conv
+func.func @pad_conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+  //      CHECK:   : tensor<1x16x16x4xf32> to tensor<1x16x18x16xf32>
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+  //      CHECK:   : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+  //      CHECK:   : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+  // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+      ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16 + 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16)>
+
+//     CHECK-LABEL: pad_conv_dynamic
+func.func @pad_conv_dynamic(%arg0: tensor<1x16x?x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> {
+
+  //  CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  //      CHECK: %[[D0_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+  //      CHECK: %[[D0_1:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x16x?x4xf32>
+  //      CHECK: %[[H0:.*]] = affine.apply #[[$MAP0]]()[%[[D0_0]], %[[D0_1]]]
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H0]], 12]
+  //      CHECK:   : tensor<1x16x?x4xf32> to tensor<1x16x?x16xf32>
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+  //      CHECK:   : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+  //      CHECK: %[[D1_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+  //      CHECK: %[[H1:.*]] = affine.apply #[[$MAP1]]()[%[[D0_0]], %[[D1_0]]]
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H1]], 0]
+  //      CHECK:   : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+  //      CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+  // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, %[[D2_0]], 16] [1, 1, 1, 1] : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+      ins(%arg0, %arg1: tensor<1x16x?x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32>
+  return %0 : tensor<1x14x?x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+//     CHECK-LABEL: pad_conv_strided
+func.func @pad_conv_strided(%arg0: tensor<1x42x42x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 6, 12]
+  //      CHECK:   : tensor<1x42x42x4xf32> to tensor<1x42x48x16xf32>
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+  //      CHECK:   : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+  //      CHECK:   : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+  // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<3> : tensor<2xi64> }
+      ins(%arg0, %arg1: tensor<1x42x42x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+//     CHECK-LABEL: pad_conv_dilated
+func.func @pad_conv_dilated(%arg0: tensor<1x18x18x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+  //      CHECK:   : tensor<1x18x18x4xf32> to tensor<1x18x20x16xf32>
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+  //      CHECK:   : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+  //      CHECK:   : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+  // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+      ins(%arg0, %arg1: tensor<1x18x18x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}

diff  --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index 26c03ed309c05..f7418769f79ca 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -69,22 +69,22 @@ module {
 
 // CHECK-LABEL: @generic
 // CHECK-SAME:      %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME:      %[[T1:.*]]: tensor<7x11x12xf32>)
-  func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME:      %[[T1:.*]]: tensor<7x11x11xf32>)
+  func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {
 
   //  CHECK-DAG: %[[CST:.*]] = arith.constant 0.
 
   //      CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[1, 0]
   //      CHECK:   : tensor<7x5xf32> to tensor<8x5xf32>
   //      CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[1, 3, 1] {
-  //      CHECK:   : tensor<7x11x12xf32> to tensor<8x14x13xf32>
+  //      CHECK:   : tensor<7x11x11xf32> to tensor<8x14x12xf32>
   // CHECK-NEXT: linalg.generic
-  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<8x14x13xf32> to tensor<7x11x12xf32>
-  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<8x14x12xf32> to tensor<7x11x11xf32>
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
     ^bb0(%in: f32, %out: f32):
       linalg.yield %in : f32
-    } -> tensor<7x11x12xf32>
-    return %0 : tensor<7x11x12xf32>
+    } -> tensor<7x11x11xf32>
+    return %0 : tensor<7x11x11xf32>
   }
   module attributes {transform.with_named_sequence} {
     transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -102,7 +102,7 @@ module {
 
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (-s0 + 8)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 13)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 12)>
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>
 
 #map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -127,13 +127,13 @@ module {
   //      CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<?x11x?xf32>
   //      CHECK: %[[H2:.*]] = affine.apply #[[$MAP1]]()[%[[D2_0]]]
   //      CHECK: tensor.pad %{{.*}} low[0, 0, 0] high[%[[H1]], 3, %[[H2]]] {
-  //      CHECK:   : tensor<?x11x?xf32> to tensor<8x14x13xf32>
+  //      CHECK:   : tensor<?x11x?xf32> to tensor<8x14x12xf32>
   //
   //      CHECK: %[[D0_2:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x5xf32>
   //      CHECK: %[[D2_1:.*]] = affine.apply #[[$MAP2]]()[%[[D0_2]]]
-  //      CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x13xf32>) {
-  //      CHECK: } -> tensor<8x14x13xf32>
-  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x13xf32> to tensor<?x11x?xf32>
+  //      CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x12xf32>) {
+  //      CHECK: } -> tensor<8x14x12xf32>
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x12xf32> to tensor<?x11x?xf32>
   //
   %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<?x5xf32>) outs(%arg1 : tensor<?x11x?xf32>) {
     ^bb0(%in: f32, %out: f32):


        

