[Mlir-commits] [mlir] [mlir] Fix padding shape computation in PadTilingInterface (PR #149576)

Vivian Zhang llvmlistbot at llvm.org
Mon Jul 21 21:49:20 PDT 2025


https://github.com/yzhang93 updated https://github.com/llvm/llvm-project/pull/149576

From a0b980a7a0188a82837fc165c5bcc13399bb7439 Mon Sep 17 00:00:00 2001
From: yzhang93 <zhyuhang88 at gmail.com>
Date: Fri, 18 Jul 2025 19:24:48 +0000
Subject: [PATCH 1/2] [mlir] Fix padding shape computation in
 PadTilingInterface

---
 .../Linalg/Transforms/PadTilingInterface.cpp  | 17 +++-
 ...m-op-pad-tiling-interface-multiple-of.mlir | 87 +++++++++++++++++--
 .../transform-op-pad-tiling-interface.mlir    | 24 ++---
 3 files changed, 104 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 5eb3761f7aca1..c465383771617 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -114,24 +114,31 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
                             /*compressDims=*/true);
 
       // If we are padding to the next multiple of, compose with ceil(sz) * sz.
+      OpFoldResult paddingDimOfr;
       if (options.padToMultipleOf) {
         AffineExpr d0, s0;
         bindDims(rewriter.getContext(), d0);
         bindSymbols(rewriter.getContext(), s0);
         AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
         AffineMap composedMap = projectedMap.compose(ceilMap);
-        OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+        paddingDimOfr = affine::makeComposedFoldedAffineApply(
             rewriter, loc, composedMap,
             {indexingSizes[paddingDim], paddingSize},
             /*composeAffineMin=*/true);
-        terms.push_back(paddingDimOfr);
       } else {
         // Otherwise just set to paddingSize.
-        OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+        paddingDimOfr = affine::makeComposedFoldedAffineApply(
             rewriter, loc, projectedMap, paddingSize);
-        terms.push_back(paddingDimOfr);
       }
 
+      // Adjust for the maximum accessed index which is (padding_size - 1).
+      AffineExpr d0;
+      bindDims(rewriter.getContext(), d0);
+      AffineMap subtractOneMap = AffineMap::get(1, 0, d0 - 1);
+      OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, subtractOneMap, {paddingDimOfr});
+      terms.push_back(maxAccessIdx);
+
       LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
     }
 
@@ -148,6 +155,8 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
     AffineExpr sumExpr = dims.front();
     for (unsigned i = 1; i < dims.size(); ++i)
       sumExpr = sumExpr + dims[i];
+    // Add 1 to the maximum accessed index and get the final padded size.
+    sumExpr = sumExpr + rewriter.getAffineConstantExpr(1);
     OpFoldResult paddedDimOfr =
         affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms);
     paddedShape[resultIndex] = paddedDimOfr;
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
index 78619b682673e..53cb7d7767b9a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
@@ -52,22 +52,22 @@ module {
 
 // CHECK-LABEL: @generic
 // CHECK-SAME:      %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME:      %[[T1:.*]]: tensor<7x11x12xf32>)
-  func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME:      %[[T1:.*]]: tensor<7x11x11xf32>)
+  func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {
 
   //  CHECK-DAG: %[[CST:.*]] = arith.constant 0.
 
   //      CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[2, 0]
   //      CHECK:   : tensor<7x5xf32> to tensor<9x5xf32>
   //      CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[2, 4, 2] {
-  //      CHECK:   : tensor<7x11x12xf32> to tensor<9x15x14xf32>
+  //      CHECK:   : tensor<7x11x11xf32> to tensor<9x15x13xf32>
   // CHECK-NEXT: linalg.generic
-  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<9x15x14xf32> to tensor<7x11x12xf32>
-  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<9x15x13xf32> to tensor<7x11x11xf32>
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
     ^bb0(%in: f32, %out: f32):
       linalg.yield %in : f32
-    } -> tensor<7x11x12xf32>
-    return %0 : tensor<7x11x12xf32>
+    } -> tensor<7x11x11xf32>
+    return %0 : tensor<7x11x11xf32>
   }
   module attributes {transform.with_named_sequence} {
     transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -83,7 +83,7 @@ module {
 // -----
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 5)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 4)>
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>
 
 #map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -272,3 +272,74 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+//     CHECK-LABEL: pad_conv
+func.func @pad_conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+  //      CHECK:   : tensor<1x16x16x4xf32> to tensor<1x16x18x16xf32>
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+  //      CHECK:   : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+  //      CHECK:   : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+  // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+      ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16 + 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16)>
+
+//     CHECK-LABEL: pad_conv_dynamic
+func.func @pad_conv_dynamic(%arg0: tensor<1x16x?x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> {
+
+  //  CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  //      CHECK: %[[D0_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+  //      CHECK: %[[D0_1:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x16x?x4xf32>
+  //      CHECK: %[[H0:.*]] = affine.apply #[[$MAP0]]()[%[[D0_0]], %[[D0_1]]]
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H0]], 12]
+  //      CHECK:   : tensor<1x16x?x4xf32> to tensor<1x16x?x16xf32>
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+  //      CHECK:   : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+  //      CHECK: %[[D1_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+  //      CHECK: %[[H1:.*]] = affine.apply #[[$MAP1]]()[%[[D0_0]], %[[D1_0]]]
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H1]], 0]
+  //      CHECK:   : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+  //      CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+  // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, %[[D2_0]], 16] [1, 1, 1, 1] : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+      ins(%arg0, %arg1: tensor<1x16x?x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32>
+  return %0 : tensor<1x14x?x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index 26c03ed309c05..f7418769f79ca 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -69,22 +69,22 @@ module {
 
 // CHECK-LABEL: @generic
 // CHECK-SAME:      %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME:      %[[T1:.*]]: tensor<7x11x12xf32>)
-  func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME:      %[[T1:.*]]: tensor<7x11x11xf32>)
+  func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {
 
   //  CHECK-DAG: %[[CST:.*]] = arith.constant 0.
 
   //      CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[1, 0]
   //      CHECK:   : tensor<7x5xf32> to tensor<8x5xf32>
   //      CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[1, 3, 1] {
-  //      CHECK:   : tensor<7x11x12xf32> to tensor<8x14x13xf32>
+  //      CHECK:   : tensor<7x11x11xf32> to tensor<8x14x12xf32>
   // CHECK-NEXT: linalg.generic
-  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<8x14x13xf32> to tensor<7x11x12xf32>
-  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<8x14x12xf32> to tensor<7x11x11xf32>
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
     ^bb0(%in: f32, %out: f32):
       linalg.yield %in : f32
-    } -> tensor<7x11x12xf32>
-    return %0 : tensor<7x11x12xf32>
+    } -> tensor<7x11x11xf32>
+    return %0 : tensor<7x11x11xf32>
   }
   module attributes {transform.with_named_sequence} {
     transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -102,7 +102,7 @@ module {
 
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (-s0 + 8)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 13)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 12)>
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>
 
 #map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -127,13 +127,13 @@ module {
   //      CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<?x11x?xf32>
   //      CHECK: %[[H2:.*]] = affine.apply #[[$MAP1]]()[%[[D2_0]]]
   //      CHECK: tensor.pad %{{.*}} low[0, 0, 0] high[%[[H1]], 3, %[[H2]]] {
-  //      CHECK:   : tensor<?x11x?xf32> to tensor<8x14x13xf32>
+  //      CHECK:   : tensor<?x11x?xf32> to tensor<8x14x12xf32>
   //
   //      CHECK: %[[D0_2:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x5xf32>
   //      CHECK: %[[D2_1:.*]] = affine.apply #[[$MAP2]]()[%[[D0_2]]]
-  //      CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x13xf32>) {
-  //      CHECK: } -> tensor<8x14x13xf32>
-  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x13xf32> to tensor<?x11x?xf32>
+  //      CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x12xf32>) {
+  //      CHECK: } -> tensor<8x14x12xf32>
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x12xf32> to tensor<?x11x?xf32>
   //
   %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<?x5xf32>) outs(%arg1 : tensor<?x11x?xf32>) {
     ^bb0(%in: f32, %out: f32):

From 01f2f3a93a0d13827a90eab69d07abd12f85fd4c Mon Sep 17 00:00:00 2001
From: yzhang93 <zhyuhang88 at gmail.com>
Date: Tue, 22 Jul 2025 04:48:22 +0000
Subject: [PATCH 2/2] Add cases for non-unit strides and dilations

---
 .../Linalg/Transforms/PadTilingInterface.cpp  | 30 ++++++++-
 ...m-op-pad-tiling-interface-multiple-of.mlir | 62 +++++++++++++++++++
 2 files changed, 89 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index c465383771617..3f9b48c4fdbfb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -55,6 +55,28 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
   return paddingSizes;
 }
 
+/// Extracts the constant multiplier from an affine expression of the form
+/// `d * c` or `c * d`, where `d` is an AffineDimExpr and `c` is an
+/// AffineConstantExpr. Returns 1 if the expression is not a simple
+/// multiplication of a dimension and a constant.
+static int64_t extractConstantMultiplier(AffineExpr expr) {
+  if (auto binOp = dyn_cast<AffineBinaryOpExpr>(expr)) {
+    if (binOp.getKind() == AffineExprKind::Mul) {
+      auto lhsD = dyn_cast<AffineDimExpr>(binOp.getLHS());
+      auto rhsC = dyn_cast<AffineConstantExpr>(binOp.getRHS());
+      if (lhsD && rhsC) {
+        return rhsC.getValue();
+      }
+      auto lhsC = dyn_cast<AffineConstantExpr>(binOp.getLHS());
+      auto rhsD = dyn_cast<AffineDimExpr>(binOp.getRHS());
+      if (lhsC && rhsD) {
+        return lhsC.getValue();
+      }
+    }
+  }
+  return 1;
+}
+
 /// Compute the padded shape of the given value `v` of `RankedTensorType` given
 ///   - `indexingSizes` a list of OpFoldResult.
 ///   - an `indexingMap` that encodes how the shape of varies with increases
@@ -131,12 +153,14 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
             rewriter, loc, projectedMap, paddingSize);
       }
 
-      // Adjust for the maximum accessed index which is (padding_size - 1).
+      // Adjust for the maximum accessed index, which is (paddingSize - 1) *
+      // multiplier.
       AffineExpr d0;
       bindDims(rewriter.getContext(), d0);
-      AffineMap subtractOneMap = AffineMap::get(1, 0, d0 - 1);
+      int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
+      AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
       OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, subtractOneMap, {paddingDimOfr});
+          rewriter, loc, subtractMap, {paddingDimOfr});
       terms.push_back(maxAccessIdx);
 
       LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
index 53cb7d7767b9a..981f5dc37c859 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
@@ -343,3 +343,65 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+//     CHECK-LABEL: pad_conv_strided
+func.func @pad_conv_strided(%arg0: tensor<1x42x42x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 6, 12]
+  //      CHECK:   : tensor<1x42x42x4xf32> to tensor<1x42x48x16xf32>
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+  //      CHECK:   : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+  //      CHECK:   : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+  // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<3> : tensor<2xi64> }
+      ins(%arg0, %arg1: tensor<1x42x42x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+//     CHECK-LABEL: pad_conv_dilated
+func.func @pad_conv_dilated(%arg0: tensor<1x18x18x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+  //      CHECK:   : tensor<1x18x18x4xf32> to tensor<1x18x20x16xf32>
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+  //      CHECK:   : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+  //      CHECK:   : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+  // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+      ins(%arg0, %arg1: tensor<1x18x18x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}


