[Mlir-commits] [mlir] 0609eb1 - [mlir][linalg] Remove padding from tiling options.
Tobias Gysi
llvmlistbot at llvm.org
Wed Nov 10 05:35:37 PST 2021
Author: Tobias Gysi
Date: 2021-11-10T13:33:28Z
New Revision: 0609eb1b32c26d5c3a440e413ea79191151ecf10
URL: https://github.com/llvm/llvm-project/commit/0609eb1b32c26d5c3a440e413ea79191151ecf10
DIFF: https://github.com/llvm/llvm-project/commit/0609eb1b32c26d5c3a440e413ea79191151ecf10.diff
LOG: [mlir][linalg] Remove padding from tiling options.
Remove the padding options from the tiling options since padding is now implemented by a separate pattern/pass introduced in https://reviews.llvm.org/D112412.
The revision removes tile-and-pad-tensors.mlir and replaces it with pad.mlir, which tests padding in isolation (without tiling). Similarly, hoist-padding.mlir is replaced by pad-and-hoist.mlir, introduced in https://reviews.llvm.org/D112713.
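For reference, the old and new test invocations side by side; both RUN lines are copied verbatim from the diff below (tile-and-pad-tensors.mlir vs. pad.mlir):
  // Before: padding was driven through the tiling test pattern.
  // RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1,2 nofold-operands=0,1 tile-sizes=2,3,4" -canonicalize | FileCheck %s
  // After: padding is exercised in isolation by the pad pattern.
  // RUN: mlir-opt %s -test-linalg-transform-patterns="test-pad-pattern pack-paddings=1,1,0 hoist-paddings=0,0,0" -cse -canonicalize -split-input-file | FileCheck %s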
Depends On D112838
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D113382
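A minimal C++ sketch of the API change for clients of LinalgTilingOptions; the tile sizes and variable names are illustrative, and only identifiers visible in this patch are used:
  linalg::LinalgTilingOptions tilingOptions;
  tilingOptions.setTileSizes({2, 3, 4});
  // Before this revision, padding was attached to the tiling options
  // (these setters no longer exist, so the lines below no longer compile):
  //   tilingOptions.setPaddingValueComputationFunction(paddingFunc);
  //   tilingOptions.setPaddingNoFoldComputationFunction(nofoldFunc);
  // After this revision, padding is configured on the separate padding
  // pattern introduced in https://reviews.llvm.org/D112412, e.g. via the
  // pack-paddings / hoist-paddings flags of the test pass shown above.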
Added:
mlir/test/Dialect/Linalg/pad.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
mlir/test/Dialect/Linalg/hoist-padding.mlir
mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 5e38ae3acebf0..e05dde330797b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -577,31 +577,6 @@ struct LinalgTilingOptions {
return *this;
}
- /// Callback returning the padding value to use for a given OpOperand or
- /// failure for no padding. Padding operations are introduced if
- /// `paddingValueComputationFunction` is set and does not return failure.
- /// Padding all operands guarantees the operation is statically shaped and
- /// thus can be vectorized.
- PaddingValueComputationFunction paddingValueComputationFunction = nullptr;
-
- LinalgTilingOptions &
- setPaddingValueComputationFunction(PaddingValueComputationFunction fun) {
- paddingValueComputationFunction = std::move(fun);
- return *this;
- }
-
- /// Callback returning true if the pad tensor operation defining the given
- /// OpOperand shall be marked as nofold to enable packing. A padding operation
- /// is only marked nofold if `paddingNoFoldComputationFunction` is set and
- /// returns true. Otherwise, the nofold attribute is set to false.
- PaddingNoFoldComputationFunction paddingNoFoldComputationFunction = nullptr;
-
- LinalgTilingOptions &
- setPaddingNoFoldComputationFunction(PaddingNoFoldComputationFunction fun) {
- paddingNoFoldComputationFunction = std::move(fun);
- return *this;
- }
-
/// Peel the specified loops.
SmallVector<int64_t> peeledLoops;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index af3e528212f7f..a0c47ad9aeaaf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -335,30 +335,8 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
// Peel loops.
peelLoops(rewriter, *res, options);
- // Consider padding on the fly only if the op has tensor semantics.
- if (!options.paddingValueComputationFunction ||
- !linalgOp.hasTensorSemantics()) {
- result = *res;
- return success();
- }
-
- // Try to pad on the fly by rewriting res->op as a padded op. If successful,
- // `res.op` is rewritten in static form with padded operands.
- LinalgOp paddedOp;
- FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
- rewriter, res->op, options.paddingValueComputationFunction,
- options.paddingNoFoldComputationFunction, paddedOp);
- if (succeeded(newResults)) {
- rewriter.replaceOp(res->op, newResults.getValue());
- filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
- res->op = paddedOp;
- result = *res;
- // Do not perform replacement of `linalgOp`, let the derived patterns
- // do this as they see fit, from the resulting TiledLinalgOp.
- return success();
- }
- // Set so RAII guard does not propagate TiledLinalgOp to `result`.
- return failure();
+ result = *res;
+ return success();
}
static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
diff --git a/mlir/test/Dialect/Linalg/hoist-padding.mlir b/mlir/test/Dialect/Linalg/hoist-padding.mlir
deleted file mode 100644
index 8f24fe81a9eaa..0000000000000
--- a/mlir/test/Dialect/Linalg/hoist-padding.mlir
+++ /dev/null
@@ -1,277 +0,0 @@
-// Specific structural checks are performed on 2-level hoisting
-// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-hoist-padding=2 -canonicalize | FileCheck %s
-
-// IR verification is performed on [0-6]-level hoisting
-// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-hoist-padding=0 | FileCheck %s --check-prefix=VERIFIER-ONLY
-// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-hoist-padding=1 | FileCheck %s --check-prefix=VERIFIER-ONLY
-// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-hoist-padding=3 | FileCheck %s --check-prefix=VERIFIER-ONLY
-// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-hoist-padding=4 | FileCheck %s --check-prefix=VERIFIER-ONLY
-// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-hoist-padding=5 | FileCheck %s --check-prefix=VERIFIER-ONLY
-// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-hoist-padding=6 | FileCheck %s --check-prefix=VERIFIER-ONLY
-
-// CHECK-DAG: #[[$DIV3:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 3)>
-// CHECK-DAG: #[[$DIV4:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 4)>
-// CHECK-DAG: #[[$DIVS3:[0-9a-z]+]] = affine_map<()[s0] -> (s0 ceildiv 3)>
-// CHECK-DAG: #[[$DIVS4:[0-9a-z]+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
-#map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)>
-#map1 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
-#map2 = affine_map<(d0)[s0] -> (3, -d0 + s0)>
-#map3 = affine_map<(d0, d1) -> (2, d0 - d1)>
-#map4 = affine_map<(d0, d1) -> (3, d0 - d1)>
-
-// CHECK-LABEL: func @matmul_tensors
-// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor
-// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor
-// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor
-// VERIFIER-ONLY-LABEL: func @matmul_tensors
-func @matmul_tensors(
- %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
- -> tensor<?x?xf32>
-{
- %c2 = arith.constant 2 : index
- %c3 = arith.constant 3 : index
- %c4 = arith.constant 4 : index
- %cst = arith.constant 0.000000e+00 : f32
-
- // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
-
- // CHECK-DAG: %[[dM:.*]] = tensor.dim %[[TA]], %[[C0]] : tensor<?x?xf32>
- // CHECK-DAG: %[[dK:.*]] = tensor.dim %[[TA]], %[[C1]] : tensor<?x?xf32>
- // CHECK-DAG: %[[dN:.*]] = tensor.dim %[[TB]], %[[C1]] : tensor<?x?xf32>
- %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
- %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
- %2 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
-
- // CHECK: scf.for %[[I:[0-9a-z]+]] =
- // First padded tensor is MxKx2x4 under loop M so Kx2x4
- // CHECK: %[[SZpad0_K:[0-9]+]] = affine.apply #[[$DIVS4]]()[%[[dK]]]
- // CHECK: linalg.init_tensor [%[[SZpad0_K]], 2, 4] : tensor<?x2x4xf32>
- // 1-D loop
- // CHECK: %[[A:.*]] = scf.for %[[J1:[0-9a-z]+]] =
- // Iteration count along J1
- // CHECK: %[[IDXpad0_K:[0-9]+]] = affine.apply #[[$DIV4]](%[[J1]])
- // CHECK: tensor.extract_slice %{{.*}} [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- // CHECK: linalg.pad_tensor %{{.*}}
- // CHECK: : tensor<?x?xf32> to tensor<2x4xf32>
- // CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%[[IDXpad0_K]], 0, 0]
- // CHECK-SAME: [1, 2, 4] [1, 1, 1] : tensor<2x4xf32> into tensor<?x2x4xf32>
- // Second tensor is KxN but loop order is (M, N, K) so padded tensor is NxKx4x3
- // CHECK: %[[SZpad1_N:[0-9]+]] = affine.apply #[[$DIVS3]]()[%[[dN]]]
- // CHECK: %[[SZpad1_K:[0-9]+]] = affine.apply #[[$DIVS4]]()[%[[dK]]]
- // CHECK: linalg.init_tensor [%[[SZpad1_N]], %[[SZpad1_K]], 4, 3] : tensor<?x?x4x3xf32>
- // 2-D loop
- // CHECK: %[[B:.*]] = scf.for %[[K2:[0-9a-z]+]] =
- // Iteration count along K2
- // CHECK: %[[IDXpad1_K:[0-9]+]] = affine.apply #[[$DIV3]](%[[K2]])
- // CHECK: scf.for %[[J2:[0-9a-z]+]] =
- // Iteration count along J2
- // CHECK: %[[IDXpad1_N:[0-9]+]] = affine.apply #[[$DIV4]](%[[J2]])
- // CHECK: tensor.extract_slice %{{.*}} [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- // CHECK: linalg.pad_tensor %{{.*}}
- // CHECK: : tensor<?x?xf32> to tensor<4x3xf32>
- // CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%[[IDXpad1_K]], %[[IDXpad1_N]], 0, 0]
- // CHECK-SAME: [1, 1, 4, 3] [1, 1, 1, 1] : tensor<4x3xf32> into tensor<?x?x4x3xf32>
- // 2-D loop
- // CHECK: scf.for %[[J:[0-9a-zA-Z]+]]
- // CHECK: scf.for %[[K:[0-9a-zA-Z]+]]
- // Iteration count along K
- // CHECK: %[[IDXpad0_K:[0-9]+]] = affine.apply #[[$DIV4]](%[[K]])
- // CHECK: %[[stA:.*]] = tensor.extract_slice %[[A]][%[[IDXpad0_K]], 0, 0] [1, 2, 4] [1, 1, 1] :
- // CHECK-SAME: tensor<?x2x4xf32> to tensor<2x4xf32>
- // Iteration count along J
- // CHECK: %[[IDXpad1_N:[0-9]+]] = affine.apply #[[$DIV3]](%[[J]])
- // Iteration count along K
- // CHECK: %[[IDXpad1_K:[0-9]+]] = affine.apply #[[$DIV4]](%[[K]])
- // CHECK: %[[stB:.*]] = tensor.extract_slice %[[B]][%[[IDXpad1_N]], %[[IDXpad1_K]], 0, 0] [1, 1, 4, 3] [1, 1, 1, 1] :
- // CHECK-SAME: tensor<?x?x4x3xf32> to tensor<4x3xf32>
- // CHECK: %[[stC:.*]] = linalg.pad_tensor %{{.*}}
- // CHECK: : tensor<?x?xf32> to tensor<2x3xf32>
- // CHECK: linalg.matmul ins(%[[stA]], %[[stB]] : tensor<2x4xf32>, tensor<4x3xf32>)
- // CHECK-SAME: outs(%[[stC]] : tensor<2x3xf32>) -> tensor<2x3xf32>
- %3 = scf.for %arg3 = %c0 to %0 step %c2 iter_args(%arg4 = %arg2) -> (tensor<?x?xf32>) {
- %4 = scf.for %arg5 = %c0 to %2 step %c3 iter_args(%arg6 = %arg4) -> (tensor<?x?xf32>) {
- %5 = scf.for %arg7 = %c0 to %1 step %c4 iter_args(%arg8 = %arg6) -> (tensor<?x?xf32>) {
- %6 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
- %7 = affine.min #map0(%arg3)[%6]
- %8 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
- %9 = affine.min #map1(%arg7)[%8]
- %10 = tensor.extract_slice %arg0[%arg3, %arg7] [%7, %9] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- %11 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
- %12 = affine.min #map1(%arg7)[%11]
- %13 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
- %14 = affine.min #map2(%arg5)[%13]
- %15 = tensor.extract_slice %arg1[%arg7, %arg5] [%12, %14] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- %16 = tensor.dim %arg8, %c0 : tensor<?x?xf32>
- %17 = affine.min #map3(%16, %arg3)
- %18 = tensor.dim %arg8, %c1 : tensor<?x?xf32>
- %19 = affine.min #map4(%18, %arg5)
- %20 = tensor.extract_slice %arg8[%arg3, %arg5] [%17, %19] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- %21 = arith.subi %c2, %7 : index
- %22 = arith.subi %c4, %9 : index
- %23 = linalg.pad_tensor %10 low[%c0, %c0] high[%21, %22] {
- ^bb0(%arg9: index, %arg10: index): // no predecessors
- linalg.yield %cst : f32
- } : tensor<?x?xf32> to tensor<2x4xf32>
- %24 = arith.subi %c4, %12 : index
- %25 = arith.subi %c3, %14 : index
- %26 = linalg.pad_tensor %15 low[%c0, %c0] high[%24, %25] {
- ^bb0(%arg9: index, %arg10: index): // no predecessors
- linalg.yield %cst : f32
- } : tensor<?x?xf32> to tensor<4x3xf32>
- %27 = arith.subi %c2, %17 : index
- %28 = arith.subi %c3, %19 : index
- %29 = linalg.pad_tensor %20 low[%c0, %c0] high[%27, %28] {
- ^bb0(%arg9: index, %arg10: index): // no predecessors
- linalg.yield %cst : f32
- } : tensor<?x?xf32> to tensor<2x3xf32>
- %30 = linalg.matmul ins(%23, %26 : tensor<2x4xf32>, tensor<4x3xf32>) outs(%29 : tensor<2x3xf32>) -> tensor<2x3xf32>
- %31 = tensor.extract_slice %30[0, 0] [%7, %14] [1, 1] : tensor<2x3xf32> to tensor<?x?xf32>
- %32 = tensor.insert_slice %31 into %arg8[%arg3, %arg5] [%17, %19] [%c1, %c1] : tensor<?x?xf32> into tensor<?x?xf32>
- scf.yield %32 : tensor<?x?xf32>
- }
- scf.yield %5 : tensor<?x?xf32>
- }
- scf.yield %4 : tensor<?x?xf32>
- }
- return %3 : tensor<?x?xf32>
-}
-
-// -----
-
-// CHECK-DAG: #[[$MIN_REST8:[0-9a-z]+]] = affine_map<(d0)[s0] -> (8, -d0 + s0)>
-// CHECK-DAG: #[[$DIV4:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 4)>
-// CHECK-DAG: #[[$DIV2:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 2)>
-#map0 = affine_map<(d0)[s0] -> (8, -d0 + s0)>
-#map1 = affine_map<(d0, d1) -> (4, -d0 + d1)>
-#map2 = affine_map<(d0, d1) -> (2, -d0 + d1)>
-#map3 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
-
-// CHECK-LABEL: func @dot
-// VERIFIER-ONLY-LABEL: func @dot
-func @dot(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<f32>)
- -> tensor<f32>
-{
- %cst = arith.constant 0.000000e+00 : f32
- %c8 = arith.constant 8 : index
- %c0 = arith.constant 0 : index
- %c4 = arith.constant 4 : index
- %c2 = arith.constant 2 : index
- %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
-
- // CHECK: scf.for %[[I:[0-9a-z]+]] =
- //
- // CHECK: %[[MR8:.*]] = affine.min #[[$MIN_REST8]](%[[I]])
- // Init tensor and pack.
- // CHECK: %[[INIT_PACKED_A:.*]] = linalg.init_tensor [2, 2, 2] : tensor<2x2x2xf32>
- // CHECK: %[[CAST_INIT_PACKED_A:.*]] = tensor.cast %[[INIT_PACKED_A]] : tensor<2x2x2xf32> to tensor<?x?x2xf32>
- // CHECK: %[[PACKED_A:.*]] = scf.for %[[II:[0-9a-z]+]] = {{.*}} iter_args(%{{.*}} = %[[CAST_INIT_PACKED_A]]) -> (tensor<?x?x2xf32>) {
- // CHECK: scf.for %[[III:[0-9a-z]+]] =
- // CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}, 0] [1, 1, 2] [1, 1, 1] : tensor<2xf32> into tensor<?x?x2xf32>
- //
- // Init tensor and pack.
- // CHECK: %[[INIT_PACKED_B:.*]] = linalg.init_tensor [2, 2, 2] : tensor<2x2x2xf32>
- // CHECK: %[[CAST_INIT_PACKED_B:.*]] = tensor.cast %[[INIT_PACKED_B]] : tensor<2x2x2xf32> to tensor<?x?x2xf32>
- // CHECK: %[[PACKED_B:.*]] = scf.for %[[II_2:[0-9a-z]+]] = {{.*}} iter_args(%{{.*}} = %[[CAST_INIT_PACKED_B]]) -> (tensor<?x?x2xf32>) {
- // CHECK: scf.for %[[III_2:[0-9a-z]+]] =
- // CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}, 0] [1, 1, 2] [1, 1, 1] : tensor<2xf32> into tensor<?x?x2xf32>
- // Compute.
- // CHECK: scf.for %[[II_3:[0-9a-z]+]] =
- // CHECK: scf.for %[[III_3:[0-9a-z]+]] = {{.*}} iter_args(%[[C:.*]] = %{{.*}}) -> (tensor<f32>) {
- // CHECK: %[[IDX0:.*]] = affine.apply #[[$DIV4]](%[[II_3]])
- // CHECK: %[[IDX1:.*]] = affine.apply #[[$DIV2]](%[[III_3]])
- // CHECK: %[[A:.*]] = tensor.extract_slice %[[PACKED_A]][%[[IDX0]], %[[IDX1]], 0] [1, 1, 2] [1, 1, 1] : tensor<?x?x2xf32> to tensor<2xf32>
- // CHECK: %[[IDX0_2:.*]] = affine.apply #[[$DIV4]](%[[II_3]])
- // CHECK: %[[IDX1_2:.*]] = affine.apply #[[$DIV2]](%[[III_3]])
- // CHECK: %[[B:.*]] = tensor.extract_slice %[[PACKED_B]][%[[IDX0_2]], %[[IDX1_2]], 0] [1, 1, 2] [1, 1, 1] : tensor<?x?x2xf32> to tensor<2xf32>
- // CHECK: linalg.dot ins(%[[A]], %[[B]] : tensor<2xf32>, tensor<2xf32>) outs(%[[C]] : tensor<f32>) -> tensor<f32>
-
- %1 = scf.for %arg3 = %c0 to %0 step %c8 iter_args(%arg4 = %arg2) -> (tensor<f32>) {
- %2 = affine.min #map0(%arg3)[%0]
- %3 = scf.for %arg5 = %c0 to %2 step %c4 iter_args(%arg6 = %arg4) -> (tensor<f32>) {
- %4 = affine.min #map1(%arg5, %2)
- %5 = scf.for %arg7 = %c0 to %4 step %c2 iter_args(%arg8 = %arg6) -> (tensor<f32>) {
- %6 = affine.min #map2(%arg7, %4)
- %7 = affine.apply #map3(%arg7, %arg5, %arg3)
- %8 = tensor.extract_slice %arg0[%7] [%6] [1] : tensor<?xf32> to tensor<?xf32>
- %9 = tensor.extract_slice %arg1[%7] [%6] [1] : tensor<?xf32> to tensor<?xf32>
- %10 = arith.subi %c2, %6 : index
- %11 = linalg.pad_tensor %8 low[%c0] high[%10] {
- ^bb0(%arg9: index): // no predecessors
- linalg.yield %cst : f32
- } : tensor<?xf32> to tensor<2xf32>
- %12 = linalg.pad_tensor %9 low[%c0] high[%10] {
- ^bb0(%arg9: index): // no predecessors
- linalg.yield %cst : f32
- } : tensor<?xf32> to tensor<2xf32>
- %13 = linalg.dot ins(%11, %12 : tensor<2xf32>, tensor<2xf32>) outs(%arg8 : tensor<f32>) -> tensor<f32>
- scf.yield %13 : tensor<f32>
- }
- scf.yield %5 : tensor<f32>
- }
- scf.yield %3 : tensor<f32>
- }
- return %1 : tensor<f32>
-}
-
-// -----
-
-// CHECK-LABEL: func @matmul_2d_tiling
-// VERIFIER-ONLY-LABEL: func @matmul_2d_tiling
-func @matmul_2d_tiling(%arg0: tensor<32x128xf32>, %arg1: tensor<128x64xf32>, %arg2: tensor<32x64xf32>) -> tensor<32x64xf32> {
- %c128 = arith.constant 128 : index
- %c64 = arith.constant 64 : index
- %c32 = arith.constant 32 : index
- %c16 = arith.constant 16 : index
- %cst = arith.constant 0.000000e+00 : f32
- %c2 = arith.constant 2 : index
- %c4 = arith.constant 4 : index
- %c0 = arith.constant 0 : index
- %1 = scf.for %arg3 = %c0 to %c32 step %c16 iter_args(%arg4 = %arg2) -> (tensor<32x64xf32>) {
- %2 = scf.for %arg5 = %c0 to %c64 step %c32 iter_args(%arg6 = %arg4) -> (tensor<32x64xf32>) {
- %3 = scf.for %arg7 = %c0 to %c128 step %c32 iter_args(%arg8 = %arg6) -> (tensor<32x64xf32>) {
- %4 = tensor.extract_slice %arg0[%arg3, %arg7] [16, 32] [1, 1] : tensor<32x128xf32> to tensor<16x32xf32>
- %5 = tensor.extract_slice %arg1[%arg7, %arg5] [32, 32] [1, 1] : tensor<128x64xf32> to tensor<32x32xf32>
- %6 = tensor.extract_slice %arg8[%arg3, %arg5] [16, 32] [1, 1] : tensor<32x64xf32> to tensor<16x32xf32>
- %7 = scf.for %arg9 = %c0 to %c16 step %c2 iter_args(%arg10 = %6) -> (tensor<16x32xf32>) {
- %10 = scf.for %arg11 = %c0 to %c32 step %c4 iter_args(%arg12 = %arg10) -> (tensor<16x32xf32>) {
- %11 = scf.for %arg13 = %c0 to %c32 step %c16 iter_args(%arg14 = %arg12) -> (tensor<16x32xf32>) {
- %12 = tensor.extract_slice %4[%arg9, %arg13] [2, 16] [1, 1] : tensor<16x32xf32> to tensor<2x16xf32>
- %13 = tensor.cast %12 : tensor<2x16xf32> to tensor<?x?xf32>
- %14 = tensor.extract_slice %5[%arg13, %arg11] [16, 4] [1, 1] : tensor<32x32xf32> to tensor<16x4xf32>
- %15 = tensor.cast %14 : tensor<16x4xf32> to tensor<?x?xf32>
- %16 = tensor.extract_slice %arg14[%arg9, %arg11] [2, 4] [1, 1] : tensor<16x32xf32> to tensor<2x4xf32>
- %17 = tensor.cast %16 : tensor<2x4xf32> to tensor<?x?xf32>
- %18 = linalg.pad_tensor %13 low[%c0, %c0] high[%c0, %c0] {
- ^bb0(%arg15: index, %arg16: index): // no predecessors
- linalg.yield %cst : f32
- } : tensor<?x?xf32> to tensor<2x16xf32>
- %19 = linalg.pad_tensor %15 low[%c0, %c0] high[%c0, %c0] {
- ^bb0(%arg15: index, %arg16: index): // no predecessors
- linalg.yield %cst : f32
- } : tensor<?x?xf32> to tensor<16x4xf32>
- %20 = linalg.pad_tensor %17 low[%c0, %c0] high[%c0, %c0] {
- ^bb0(%arg15: index, %arg16: index): // no predecessors
- linalg.yield %cst : f32
- } : tensor<?x?xf32> to tensor<2x4xf32>
- %21 = linalg.matmul ins(%18, %19 : tensor<2x16xf32>, tensor<16x4xf32>) outs(%20 : tensor<2x4xf32>) -> tensor<2x4xf32>
- %22 = tensor.cast %21 : tensor<2x4xf32> to tensor<?x?xf32>
- %23 = tensor.insert_slice %22 into %arg14[%arg9, %arg11] [%c2, %c4] [1, 1] : tensor<?x?xf32> into tensor<16x32xf32>
- scf.yield %23 : tensor<16x32xf32>
- }
- scf.yield %11 : tensor<16x32xf32>
- }
- scf.yield %10 : tensor<16x32xf32>
- }
- %8 = tensor.cast %7 : tensor<16x32xf32> to tensor<?x?xf32>
- %9 = tensor.insert_slice %8 into %arg8[%arg3, %arg5] [%c16, %c32] [1, 1] : tensor<?x?xf32> into tensor<32x64xf32>
- scf.yield %9 : tensor<32x64xf32>
- }
- scf.yield %3 : tensor<32x64xf32>
- }
- scf.yield %2 : tensor<32x64xf32>
- }
- return %1 : tensor<32x64xf32>
-}
diff --git a/mlir/test/Dialect/Linalg/pad.mlir b/mlir/test/Dialect/Linalg/pad.mlir
new file mode 100644
index 0000000000000..10e98a79fae5a
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/pad.mlir
@@ -0,0 +1,242 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-pad-pattern pack-paddings=1,1,0 hoist-paddings=0,0,0" -cse -canonicalize -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[MAP0:[0-9a-z]+]] = affine_map<(d0) -> (7, -d0 + 12)>
+#map = affine_map<(d0) -> (7, -d0 + 12)>
+
+// CHECK: static_sizes_output_divisible
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
+func @static_sizes_output_divisible(%arg0: tensor<24x12xf32>,
+ %arg1: tensor<12x25xf32>,
+ %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C7:.*]] = arith.constant 7
+ %c0 = arith.constant 0 : index
+ %c12 = arith.constant 12 : index
+ %c25 = arith.constant 25 : index
+ %c24 = arith.constant 24 : index
+ %c7 = arith.constant 7 : index
+ %c5 = arith.constant 5 : index
+ %c4 = arith.constant 4 : index
+
+ // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+ %0 = scf.for %arg3 = %c0 to %c24 step %c4 iter_args(%arg4 = %arg2) -> (tensor<24x25xf32>) {
+
+ // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+ %1 = scf.for %arg5 = %c0 to %c25 step %c5 iter_args(%arg6 = %arg4) -> (tensor<24x25xf32>) {
+
+ // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG4:.*]] =
+ %2 = scf.for %arg7 = %c0 to %c12 step %c7 iter_args(%arg8 = %arg6) -> (tensor<24x25xf32>) {
+
+ // CHECK: %[[TS2:.*]] = affine.min #[[MAP0]](%[[IV2]])
+ %3 = affine.min #map(%arg7)
+
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[T2:.*]] = tensor.extract_slice %[[ARG4]]
+ %4 = tensor.extract_slice %arg0[%arg3, %arg7] [4, %3] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
+ %5 = tensor.extract_slice %arg1[%arg7, %arg5] [%3, 5] [1, 1] : tensor<12x25xf32> to tensor<?x5xf32>
+ %6 = tensor.extract_slice %arg8[%arg3, %arg5] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32>
+
+ // Check statically sized matmul inputs with partially divisible sizes are padded.
+ // CHECK: %[[V0:.*]] = arith.subi %[[C7]], %[[TS2]]
+ // CHECK: %[[T3:.*]] = linalg.pad_tensor %[[T0]] nofold
+ // CHECK-SAME: [%[[C0]], %[[C0]]]
+ // CHECK-SAME: [%[[C0]], %[[V0]]
+ // CHECK: %[[T4:.*]] = linalg.pad_tensor %[[T1]] nofold
+
+ // Check the statically sized matmul output with fully divisible sizes is not padded.
+ // CHECK: %[[T5:.*]] = linalg.matmul
+ // CHECK-SAME: ins(%[[T3]], %[[T4]] : tensor<4x7xf32>, tensor<7x5xf32>)
+ // CHECK-SAME: outs(%[[T2]] : tensor<4x5xf32>)
+ // CHECK: %[[T6:.*]] = tensor.insert_slice %[[T5]]
+ %7 = linalg.matmul {__internal_linalg_transform__ = "pad"} ins(%4, %5 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%6 : tensor<4x5xf32>) -> tensor<4x5xf32>
+ %8 = tensor.insert_slice %7 into %arg8[%arg3, %arg5] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
+
+ // CHECK: scf.yield %[[T6]]
+ scf.yield %8 : tensor<24x25xf32>
+ }
+ scf.yield %2 : tensor<24x25xf32>
+ }
+ scf.yield %1 : tensor<24x25xf32>
+ }
+ return %0 : tensor<24x25xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:[0-9a-z]+]] = affine_map<(d0) -> (7, -d0 + 25)>
+#map = affine_map<(d0) -> (7, -d0 + 25)>
+
+// CHECK: static_sizes_input_divisible
+func @static_sizes_input_divisible(%arg0: tensor<24x12xf32>,
+ %arg1: tensor<12x25xf32>,
+ %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C7:.*]] = arith.constant 7
+ %c0 = arith.constant 0 : index
+ %c12 = arith.constant 12 : index
+ %c25 = arith.constant 25 : index
+ %c24 = arith.constant 24 : index
+ %c6 = arith.constant 6 : index
+ %c7 = arith.constant 7 : index
+ %c4 = arith.constant 4 : index
+
+ // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+ %0 = scf.for %arg3 = %c0 to %c24 step %c4 iter_args(%arg4 = %arg2) -> (tensor<24x25xf32>) {
+
+ // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+ %1 = scf.for %arg5 = %c0 to %c25 step %c7 iter_args(%arg6 = %arg4) -> (tensor<24x25xf32>) {
+
+ // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG4:.*]] =
+ %2 = scf.for %arg7 = %c0 to %c12 step %c6 iter_args(%arg8 = %arg6) -> (tensor<24x25xf32>) {
+ %3 = tensor.extract_slice %arg0[%arg3, %arg7] [4, 6] [1, 1] : tensor<24x12xf32> to tensor<4x6xf32>
+
+ // CHECK: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]])
+ %4 = affine.min #map(%arg5)
+ %5 = tensor.extract_slice %arg1[%arg7, %arg5] [6, %4] [1, 1] : tensor<12x25xf32> to tensor<6x?xf32>
+
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG4]]
+ %6 = tensor.extract_slice %arg8[%arg3, %arg5] [4, %4] [1, 1] : tensor<24x25xf32> to tensor<4x?xf32>
+
+ // Check the statically sized matmul output with partially divisible sizes is padded.
+ // CHECK: %[[V0:.*]] = arith.subi %[[C7]], %[[TS1]]
+ // CHECK: %[[T1:.*]] = linalg.pad_tensor %[[T0]] low
+ // CHECK-SAME: [%[[C0]], %[[C0]]]
+ // CHECK-SAME: [%[[C0]], %[[V0]]
+
+ // CHECK: %[[T2:.*]] = linalg.matmul
+ // CHECK-SAME: outs(%[[T1]] : tensor<4x7xf32>)
+ // CHECK: %[[T3:.*]] = tensor.extract_slice %[[T2]]
+ // CHECK: %[[T4:.*]] = tensor.insert_slice %[[T3]]
+ %7 = linalg.matmul {__internal_linalg_transform__ = "pad"} ins(%3, %5 : tensor<4x6xf32>, tensor<6x?xf32>) outs(%6 : tensor<4x?xf32>) -> tensor<4x?xf32>
+ %8 = tensor.insert_slice %7 into %arg8[%arg3, %arg5] [4, %4] [1, 1] : tensor<4x?xf32> into tensor<24x25xf32>
+
+ // CHECK: scf.yield %[[T4]]
+ scf.yield %8 : tensor<24x25xf32>
+ }
+ scf.yield %2 : tensor<24x25xf32>
+ }
+ scf.yield %1 : tensor<24x25xf32>
+ }
+ return %0 : tensor<24x25xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:[0-9a-z]+]] = affine_map<(d0)[s0] -> (5, -d0 + s0)>
+// CHECK-DAG: #[[MAP1:[0-9a-z]+]] = affine_map<(d0)[s0] -> (7, -d0 + s0)>
+// CHECK-DAG: #[[MAP2:[0-9a-z]+]] = affine_map<(d0)[s0] -> (6, -d0 + s0)>
+#map0 = affine_map<(d0)[s0] -> (5, -d0 + s0)>
+#map1 = affine_map<(d0)[s0] -> (6, -d0 + s0)>
+#map2 = affine_map<(d0)[s0] -> (7, -d0 + s0)>
+
+// CHECK: dynamic_sizes
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<?x?xf32>
+func @dynamic_sizes(%arg0: tensor<?x?xf32>,
+ %arg1: tensor<?x?xf32>,
+ %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+ // CHECK-DAG: %[[C5:.*]] = arith.constant 5
+ // CHECK-DAG: %[[C6:.*]] = arith.constant 6
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c6 = arith.constant 6 : index
+ %c7 = arith.constant 7 : index
+ %c5 = arith.constant 5 : index
+
+ // CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]]
+ // CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+ // CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG1]], %[[C1]]
+ %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %2 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+
+ // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+ %3 = scf.for %arg3 = %c0 to %0 step %c5 iter_args(%arg4 = %arg2) -> (tensor<?x?xf32>) {
+
+ // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+ %4 = scf.for %arg5 = %c0 to %2 step %c7 iter_args(%arg6 = %arg4) -> (tensor<?x?xf32>) {
+
+ // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG4:.*]] =
+ %5 = scf.for %arg7 = %c0 to %1 step %c6 iter_args(%arg8 = %arg6) -> (tensor<?x?xf32>) {
+
+ // CHECK: %[[TS0:.*]] = affine.min #[[MAP0]](%[[IV0]])[%[[D0]]]
+ // CHECK: %[[TS2:.*]] = affine.min #[[MAP2]](%[[IV2]])[%[[D2]]]
+ // CHECK: %[[TS1:.*]] = affine.min #[[MAP1]](%[[IV1]])[%[[D1]]]
+ %6 = affine.min #map0(%arg3)[%0]
+ %7 = affine.min #map1(%arg7)[%1]
+ %8 = tensor.extract_slice %arg0[%arg3, %arg7] [%6, %7] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ %9 = affine.min #map2(%arg5)[%2]
+ %10 = tensor.extract_slice %arg1[%arg7, %arg5] [%7, %9] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ %11 = tensor.extract_slice %arg8[%arg3, %arg5] [%6, %9] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+
+ // Check all matmul operands are padded.
+ // CHECK: %[[V0:.*]] = arith.subi %[[C5]], %[[TS0]]
+ // CHECK: %[[V1:.*]] = arith.subi %[[C6]], %[[TS2]]
+ // CHECK: %[[T3:.*]] = linalg.pad_tensor %{{.*}} nofold
+ // CHECK-SAME: [%[[C0]], %[[C0]]]
+ // CHECK-SAME: [%[[V0]], %[[V1]]
+ // CHECK: %[[T4:.*]] = linalg.pad_tensor %{{.*}} nofold
+ // CHECK: %[[T5:.*]] = linalg.pad_tensor %{{.*}} low
+
+ // Check the dynamic matmul has been erased.
+ // CHECK-NOT: = linalg.matmul {{.*}} tensor<?x?xf32>
+
+ // Check all padded matmul operands are statically sized.
+ // CHECK: %[[T6:.*]] = linalg.matmul
+ // CHECK-SAME: ins(%[[T3]], %[[T4]] : tensor<5x6xf32>, tensor<6x7xf32>)
+ // CHECK-SAME: outs(%[[T5]] : tensor<5x7xf32>)
+ // CHECK: %[[T7:.*]] = tensor.extract_slice %[[T6]][0, 0] [%[[TS0]], %[[TS1]]]
+ // CHECK: %[[T8:.*]] = tensor.insert_slice %[[T7]]
+ %12 = linalg.matmul {__internal_linalg_transform__ = "pad"} ins(%8, %10 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%11 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %13 = tensor.insert_slice %12 into %arg8[%arg3, %arg5] [%6, %9] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+
+ // CHECK: scf.yield %[[T8]]
+ scf.yield %13 : tensor<?x?xf32>
+ }
+ scf.yield %5 : tensor<?x?xf32>
+ }
+ scf.yield %4 : tensor<?x?xf32>
+ }
+ return %3 : tensor<?x?xf32>
+}
+
+// -----
+
+#map = affine_map<(d0) -> (7, -d0 + 12)>
+
+// CHECK: scalar_operand
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: f32
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+func @scalar_operand(%arg0: f32, %arg1: tensor<24x12xf32>) -> tensor<24x12xf32> {
+ %c0 = arith.constant 0 : index
+ %c12 = arith.constant 12 : index
+ %c24 = arith.constant 24 : index
+ %c7 = arith.constant 7 : index
+ %c4 = arith.constant 4 : index
+
+ // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+ %0 = scf.for %arg2 = %c0 to %c24 step %c4 iter_args(%arg3 = %arg1) -> (tensor<24x12xf32>) {
+
+ // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG2:.*]] =
+ %1 = scf.for %arg4 = %c0 to %c12 step %c7 iter_args(%arg5 = %arg3) -> (tensor<24x12xf32>) {
+ %2 = affine.min #map(%arg4)
+
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[T1:.*]] = linalg.pad_tensor %[[T0]] nofold
+ %3 = tensor.extract_slice %arg5[%arg2, %arg4] [4, %2] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
+
+ // Check only the fill output operand is padded.
+ // CHECK: %[[T6:.*]] = linalg.fill(%[[ARG0]], %[[T1]]
+ %4 = linalg.fill(%arg0, %3) {__internal_linalg_transform__ = "pad"} : f32, tensor<4x?xf32> -> tensor<4x?xf32>
+ %5 = tensor.insert_slice %4 into %arg5[%arg2, %arg4] [4, %2] [1, 1] : tensor<4x?xf32> into tensor<24x12xf32>
+ scf.yield %5 : tensor<24x12xf32>
+ }
+ scf.yield %1 : tensor<24x12xf32>
+ }
+ return %0 : tensor<24x12xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
deleted file mode 100644
index 5c52d829ce566..0000000000000
--- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
+++ /dev/null
@@ -1,155 +0,0 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1,2 nofold-operands=0,1 tile-sizes=2,3,4" -canonicalize | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1 nofold-operands=0,1 tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE
-
-// CHECK-LABEL: func @matmul_tensors(
-// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xi8>
-// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<?x?xi8>
-// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
-func @matmul_tensors(
- %arg0: tensor<?x?xi8>, %arg1: tensor<?x?xi8>, %arg2: tensor<?x?xi32>)
- -> tensor<?x?xi32> {
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xi32>) {
-// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xi32>) {
-// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xi32>) {
-// CHECK: %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x?xi8> to tensor<?x?xi8>
-// CHECK: %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<?x?xi8> to tensor<?x?xi8>
-// CHECK: %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?xi32> to tensor<?x?xi32>
-
-// Dynamic op has been canonicalized away.
-// CHECK-NOT: linalg.matmul {{.*}} tensor<?x?xi8>
-
-// Padding injects static information.
-// CHECK: %[[pA:.*]] = linalg.pad_tensor %[[sTA]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
-// CHECK: : tensor<?x?xi8> to tensor<2x4xi8>
-// CHECK: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
-// CHECK: : tensor<?x?xi8> to tensor<4x3xi8>
-// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
-// CHECK: : tensor<?x?xi32> to tensor<2x3xi32>
-// CHECK: %[[pD:.*]] = linalg.matmul ins(%[[pA]], %[[pB]] : tensor<2x4xi8>, tensor<4x3xi8>)
-// CHECK-SAME: outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32>
-// CHECK: %[[sTD:.*]] = tensor.extract_slice %[[pD]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<2x3xi32> to tensor<?x?xi32>
-// CHECK: %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xi32> into tensor<?x?xi32>
-// CHECK: scf.yield %[[TD]] : tensor<?x?xi32>
-// CHECK: scf.yield %[[TD2]] : tensor<?x?xi32>
-// CHECK: scf.yield %[[TD1]] : tensor<?x?xi32>
- %0 = linalg.matmul {__internal_linalg_transform__ = "tile"}
- ins(%arg0, %arg1: tensor<?x?xi8>, tensor<?x?xi8>)
- outs(%arg2: tensor<?x?xi32>)
- -> tensor<?x?xi32>
-
-// CHECK: return %[[TD0]] : tensor<?x?xi32>
- return %0 : tensor<?x?xi32>
-}
-
-// CHECK-LABEL: func @generic_scalar_and_tensor(
-// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<?x?x?xf32>
-// CHECK-SAME: %[[VAL:[0-9a-z]+]]: f32) -> tensor<?x?x?xf32> {
-func @generic_scalar_and_tensor(
- %arg0: tensor<?x?x?xf32>, %arg1: f32)
- -> tensor<?x?x?xf32> {
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?x?xf32>) {
-// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?x?xf32>) {
-// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?x?xf32>) {
-// CHECK: %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
-
-// Padding injects static information.
-// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] nofold low[%[[C0]], %[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}, %{{.*}}]
-// CHECK: : tensor<?x?x?xf32> to tensor<2x3x4xf32>
-// CHECK: %[[pD:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[VAL]] : f32) outs(%[[pC]] : tensor<2x3x4xf32>)
-// CHECK: %[[sTD:.*]] = tensor.extract_slice %[[pD]][0, 0, 0] [%{{.*}}, %{{.*}}, %{{.*}}] [1, 1, 1] : tensor<2x3x4xf32> to tensor<?x?x?xf32>
-// CHECK: %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
-// CHECK: scf.yield %[[TD]] : tensor<?x?x?xf32>
-// CHECK: scf.yield %[[TD2]] : tensor<?x?x?xf32>
-// CHECK: scf.yield %[[TD1]] : tensor<?x?x?xf32>
- %0 = linalg.generic {
- indexing_maps = [ affine_map<(d0, d1, d2) -> ()>,
- affine_map<(d0, d1, d2) -> (d0, d1, d2)> ],
- iterator_types = ["parallel", "parallel", "parallel"]}
- {__internal_linalg_transform__ = "tile"}
- ins(%arg1 : f32)
- outs(%arg0: tensor<?x?x?xf32>) {
- ^bb(%0: f32, %1: f32) :
- linalg.yield %0 : f32
- } -> tensor<?x?x?xf32>
- return %0 : tensor<?x?x?xf32>
-}
-
-// CHECK-1DIM-TILE: func @matmul_tensors(
-// CHECK-1DIM-TILE: %[[TA:[0-9a-z]+]]: tensor<?x?xi8>
-// CHECK-1DIM-TILE: %[[TB:[0-9a-z]+]]: tensor<?x?xi8>
-// CHECK-1DIM-TILE: %[[TC:[0-9a-z]+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
-// CHECK-1DIM-TILE-NOT: scf.for
-// CHECK-1DIM-TILE: linalg.matmul ins(%[[TA]], %[[TB]] : tensor<?x?xi8>, tensor<?x?xi8>) outs(%[[TC]] : tensor<?x?xi32>) -> tensor<?x?xi32>
-
-func @matmul_partially_padded_tensors(
- %arg0: tensor<?x8xi8>, %arg1: tensor<8x?xi8>, %arg2: tensor<?x?xi32>)
- -> tensor<?x?xi32> {
- %0 = linalg.matmul {__internal_linalg_transform__ = "tile"}
- ins(%arg0, %arg1: tensor<?x8xi8>, tensor<8x?xi8>)
- outs(%arg2: tensor<?x?xi32>)
- -> tensor<?x?xi32>
- return %0 : tensor<?x?xi32>
-}
-// CHECK-LABEL: func @matmul_partially_padded_tensors(
-// CHECK: linalg.matmul ins({{.*}}, {{.*}} : tensor<2x4xi8>, tensor<4x3xi8>) outs({{.*}} : tensor<2x3xi32>) -> tensor<2x3xi32>
-
-
-// Check only the the input operands are padded.
-// CHECK-1DIM-TILE: func @matmul_partially_padded_tensors(
-// CHECK-1DIM-TILE-SAME: %[[TA:[0-9a-z]+]]: tensor<?x8xi8>
-// CHECK-1DIM-TILE-SAME: %[[TB:[0-9a-z]+]]: tensor<8x?xi8>
-// CHECK-1DIM-TILE-SAME: %[[TC:[0-9a-z]+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
-// CHECK-1DIM-TILE: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-1DIM-TILE: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xi32>) {
-// CHECK-1DIM-TILE: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xi32>) {
-// CHECK-1DIM-TILE: %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x8xi8> to tensor<?x8xi8>
-// CHECK-1DIM-TILE: %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<8x?xi8> to tensor<8x?xi8>
-// CHECK-1DIM-TILE: %[[sTC:.*]] = tensor.extract_slice %[[TC1]][{{.*}}] : tensor<?x?xi32> to tensor<?x?xi32>
-// CHECK-1DIM-TILE: %[[pA:.*]] = linalg.pad_tensor %[[sTA]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
-// CHECK-1DIM-TILE: : tensor<?x8xi8> to tensor<2x8xi8>
-// CHECK-1DIM-TILE: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
-// CHECK-1DIM-TILE: : tensor<8x?xi8> to tensor<8x3xi8>
-// CHECK-1DIM-TILE: %[[pD:.*]] = linalg.matmul ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>)
-// CHECK-1DIM-TILE: outs(%[[sTC]] : tensor<?x?xi32>) -> tensor<?x?xi32>
-
-// Check that the tile-and-pad transformation actually introduces the padding
-// as requested, even if original operation already operates on static
-// shapes.
-// CHECK-LABEL: @pad_to_same_static_size
-func @pad_to_same_static_size(%arg0: tensor<2x3x4xf32>, %arg1: f32) -> tensor<2x3x4xf32> {
- // CHECK: %[[c0:.*]] = arith.constant 0 : index
- // CHECK-NOT: scf.for
- // CHECK: linalg.pad_tensor %{{.*}} nofold low[%[[c0]], %[[c0]], %[[c0]]] high[%[[c0]], %[[c0]], %[[c0]]]
- // CHECK: tensor<2x3x4xf32> to tensor<2x3x4xf32>
- %0 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1, d2) -> ()>,
- affine_map<(d0, d1, d2) -> (d0, d1, d2)> ],
- iterator_types = ["parallel", "parallel", "parallel"]}
- {__internal_linalg_transform__ = "tile"}
- ins(%arg1 : f32) outs(%arg0 : tensor<2x3x4xf32>) {
- ^bb0(%arg2: f32, %arg3: f32): // no predecessors
- linalg.yield %arg2 : f32
- } -> tensor<2x3x4xf32>
- return %0 : tensor<2x3x4xf32>
-}
-
-// CHECK-LABEL: @pad_static_divisible_size
-func @pad_static_divisible_size(%arg0: tensor<4x6x8xf32>, %arg1: f32) -> tensor<4x6x8xf32> {
- // CHECK: %[[c0:.*]] = arith.constant 0 : index
- // CHECK-COUNT-3: scf.for
- // CHECK: linalg.pad_tensor %{{.*}} nofold low[%[[c0]], %[[c0]], %[[c0]]] high[%[[c0]], %[[c0]], %[[c0]]]
- // CHECK: tensor<2x3x4xf32> to tensor<2x3x4xf32>
- %0 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1, d2) -> ()>,
- affine_map<(d0, d1, d2) -> (d0, d1, d2)> ],
- iterator_types = ["parallel", "parallel", "parallel"]}
- {__internal_linalg_transform__ = "tile"}
- ins(%arg1 : f32) outs(%arg0 : tensor<4x6x8xf32>) {
- ^bb0(%arg2: f32, %arg3: f32): // no predecessors
- linalg.yield %arg2 : f32
- } -> tensor<4x6x8xf32>
- return %0 : tensor<4x6x8xf32>
-}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
index ef3aab0e1baa4..249217fa58c41 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -143,14 +143,6 @@ struct TestLinalgCodegenStrategy
llvm::cl::init("")};
};
-// For now, just assume it is the zero of type.
-// In the future, it should be the zero of type + op.
-static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
- auto t = getElementTypeOrSelf(op.get());
- return b.create<arith::ConstantOp>(op.getOwner()->getLoc(), t,
- b.getZeroAttr(t));
-}
-
void TestLinalgCodegenStrategy::runStrategy(
LinalgTilingOptions tilingOptions,
LinalgTilingOptions registerTilingOptions,
@@ -196,6 +188,14 @@ void TestLinalgCodegenStrategy::runStrategy(
}
} // end anonymous namespace
+// For now, just assume it is the zero of type.
+// In the future, it should be the zero of type + op.
+static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
+ auto t = getElementTypeOrSelf(op.get());
+ return b.create<arith::ConstantOp>(op.getOwner()->getLoc(), t,
+ b.getZeroAttr(t));
+}
+
/// Apply transformations specified as patterns.
void TestLinalgCodegenStrategy::runOnFunction() {
if (!anchorFuncOpName.empty() && anchorFuncOpName != getFunction().getName())
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 78711310a8d88..863ac1693d6a5 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -93,9 +93,6 @@ struct TestLinalgTransforms
*this, "test-tile-scalarize-dynamic-dims",
llvm::cl::desc("Test tiling of dynamic dims by 1"),
llvm::cl::init(false)};
- Option<int> testHoistPadding{*this, "test-hoist-padding",
- llvm::cl::desc("Test hoist padding"),
- llvm::cl::init(0)};
Option<bool> testPadPattern{*this, "test-pad-pattern",
llvm::cl::desc("Test pad pattern"),
llvm::cl::init(false)};
@@ -112,14 +109,6 @@ struct TestLinalgTransforms
llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
"pad_tensor(subtensor)"),
llvm::cl::init(false)};
- ListOption<int64_t> paddedOperands{
- *this, "padded-operands",
- llvm::cl::desc("Operands to pad when test-tile-pattern"),
- llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
- ListOption<int64_t> nofoldOperands{
- *this, "nofold-operands",
- llvm::cl::desc("Operands to set nofold when test-tile-pattern"),
- llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
ListOption<int64_t> packPaddings{
*this, "pack-paddings",
llvm::cl::desc("Operand packing flags when test-pad-pattern"),
@@ -615,8 +604,6 @@ static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
static void applyTilePattern(FuncOp funcOp, std::string loopType,
ArrayRef<int64_t> tileSizes,
- ArrayRef<int64_t> paddedOperands,
- ArrayRef<int64_t> nofoldOperands,
ArrayRef<int64_t> peeledLoops,
bool scalarizeDynamicDims) {
MLIRContext *context = funcOp.getContext();
@@ -637,21 +624,6 @@ static void applyTilePattern(FuncOp funcOp, std::string loopType,
} else {
linalgTilingOptions.setTileSizes(tileSizes);
}
- if (!paddedOperands.empty()) {
- auto paddingFunc = [&](OpBuilder &b,
- OpOperand &opOperand) -> FailureOr<Value> {
- if (llvm::count(paddedOperands, opOperand.getOperandNumber()) == 0)
- return failure();
- return getNeutralOfLinalgOp(b, opOperand);
- };
- auto nofoldFunc = [&](OpOperand &opOperand) {
- if (llvm::count(nofoldOperands, opOperand.getOperandNumber()) != 0)
- return true;
- return false;
- };
- linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc);
- linalgTilingOptions.setPaddingNoFoldComputationFunction(nofoldFunc);
- }
tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
linalg::LinalgTilingPattern<linalg::GenericOp>>(
context, linalgTilingOptions,
@@ -808,24 +780,11 @@ void TestLinalgTransforms::runOnFunction() {
return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
skipPartial);
if (testTilePattern)
- return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
- nofoldOperands, peeledLoops,
+ return applyTilePattern(getFunction(), loopType, tileSizes, peeledLoops,
/*scalarizeDynamicDims=*/false);
if (testTileScalarizeDynamicDims)
- return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
- nofoldOperands,
+ return applyTilePattern(getFunction(), loopType, tileSizes,
/*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
- if (testHoistPadding) {
- getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
- PadTensorOp hoistedOp;
- FailureOr<Value> newResult = linalg::hoistPaddingOnTensors(
- padTensorOp, testHoistPadding, hoistedOp);
- if (succeeded(newResult)) {
- padTensorOp.getResult().replaceAllUsesWith(newResult.getValue());
- padTensorOp->erase();
- }
- });
- }
if (testPadPattern)
return applyPadPattern(getFunction(), packPaddings, hoistPaddings);
if (testInterchangePattern.hasValue())