[Mlir-commits] [mlir] ad6700b - [Transform] Support more cases for the transform pad operation
Quentin Colombet
llvmlistbot at llvm.org
Mon May 8 06:47:09 PDT 2023
Author: Quentin Colombet
Date: 2023-05-08T15:40:38+02:00
New Revision: ad6700b520c742b4748dac02d0fb12b6d47b25f9
URL: https://github.com/llvm/llvm-project/commit/ad6700b520c742b4748dac02d0fb12b6d47b25f9
DIFF: https://github.com/llvm/llvm-project/commit/ad6700b520c742b4748dac02d0fb12b6d47b25f9.diff
LOG: [Transform] Support more cases for the transform pad operation
Don't choke on `outs` arguments that are not produced by `tensor.empty` or
`tensor.extract_slice`.
When the `outs` argument has a static shape we have all the necessary
information to proceed with the padding.
This makes the `transform.structured.pad` a little bit more resilient.
Differential Revision: https://reviews.llvm.org/D150112
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/transform-op-pad.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 4a5c69c5bc06..90d2aa521157 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -96,14 +96,6 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
// Fail if `currOpOperand` is not defined by an ExtractSliceOp or EmptyOp.
auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
auto emptyOp = currOpOperand->get().getDefiningOp<tensor::EmptyOp>();
- if (!sliceOp && !emptyOp) {
- // TODO: may want to add support for going through loop iter args.
- // This is not strictly necessary as we can pad before hoisting but it would
- // make the system more resilient to minor transformation reordering.
- LLVM_DEBUG(DBGS() << "--not defined by an extractSlice or emptyOp\n");
- return rewriter.notifyMatchFailure(
- opToPad, "not defined by an extractSlice or emptyOp");
- }
llvm::SmallBitVector droppedDims;
SmallVector<OpFoldResult> mixedSizes;
@@ -111,10 +103,19 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
// Compute the dropped dimensions if `sliceOp` is rank-reducing.
droppedDims = sliceOp.getDroppedDims();
mixedSizes = sliceOp.getMixedSizes();
- }
- if (emptyOp) {
+ } else if (emptyOp) {
mixedSizes = emptyOp.getMixedSizes();
droppedDims.resize(mixedSizes.size());
+ } else if (hasStaticShape) {
+ mixedSizes = getAsIndexOpFoldResult(rewriter.getContext(), shape);
+ droppedDims.resize(mixedSizes.size());
+ } else {
+ // TODO: may want to add support for going through loop iter args.
+ // This is not strictly necessary as we can pad before hoisting but it would
+ // make the system more resilient to minor transformation reordering.
+ LLVM_DEBUG(DBGS() << "--not defined by an extractSlice or emptyOp\n");
+ return rewriter.notifyMatchFailure(
+ opToPad, "not defined by an extractSlice or emptyOp");
}
LLVM_DEBUG(llvm::interleaveComma(mixedSizes, DBGS() << "--mixedSizes: ");
llvm::dbgs() << "\n");
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index 685f70648b04..3677c1afeecb 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -127,6 +127,7 @@ transform.sequence failures(propagate) {
// -----
+// CHECK-LABEL: @pad(
func.func @pad(%arg0: tensor<24x12xf32>,
%arg1: tensor<12x25xf32>,
%arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
@@ -147,3 +148,63 @@ transform.sequence failures(suppress) {
pack_paddings=[1, 1, 0]
}
}
+
+// -----
+
+// Check that the padding can be applied even when the output argument of the
+// linalg op is not produced by an empty op or an extract_slice op.
+
+// CHECK-DAG: #[[$MAP_MIN:.*]] = affine_map<(d0) -> (-d0 + 2044, 16)>
+// CHECK-DAG: #[[$MAP_C0:.*]] = affine_map<() -> (0)>
+// CHECK-DAG: #[[$MAP_TO_16:.*]] = affine_map<(d0) -> (-d0 + 16)>
+// CHECK-LABEL: @outs_not_produced_by_empty_or_extract_slice(
+// CHECK-SAME: %[[A:[^: ]*]]: tensor<128x2044xf32>,
+// CHECK-SAME: %[[B:[^: ]*]]: tensor<2044x128xf32>)
+func.func @outs_not_produced_by_empty_or_extract_slice(%a : tensor<128x2044xf32>, %b : tensor<2044x128xf32>) -> tensor<128x128xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<128x128xf32>
+ %9 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c2044 = arith.constant 2044 : index
+ // CHECK: scf.for %[[ARG3:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %{{.*}})
+ %10 = scf.for %arg3 = %c0 to %c2044 step %c16 iter_args(%arg4 = %9) -> (tensor<128x128xf32>) {
+ // CHECK: %[[MIN:.*]] = affine.min #[[$MAP_MIN]](%[[ARG3]])
+ %11 = affine.min affine_map<(d0) -> (-d0 + 2044, 16)>(%arg3)
+ // CHECK: %[[A_SLICE:.*]] = tensor.extract_slice %[[A]]
+ // CHECK: %[[B_SLICE:.*]] = tensor.extract_slice %[[B]]
+ %extracted_slice_2 = tensor.extract_slice %a[0, %arg3] [128, %11] [1, 1] : tensor<128x2044xf32> to tensor<128x?xf32>
+ %extracted_slice_3 = tensor.extract_slice %b[%arg3, 0] [%11, 128] [1, 1] : tensor<2044x128xf32> to tensor<?x128xf32>
+ // CHECK-DAG: %[[CST:.*]] = arith.constant 0.
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+
+ // CHECK-DAG: %[[ZERO:.*]] = affine.apply #[[$MAP_C0]]()
+ // CHECK-DAG: %[[TO_16:.*]] = affine.apply #[[$MAP_TO_16]](%[[MIN]])
+ // CHECK: %[[PADDED_A_SLICE:.*]] = tensor.pad %[[A_SLICE]] nofold low[%[[C0]], %[[C0]]] high[%[[ZERO]], %[[TO_16]]]
+ // CHECK: tensor.yield %[[CST]]
+ // CHECK: %[[PADDED_B_SLICE:.*]] = tensor.pad %[[B_SLICE]] nofold
+ // The output shape is already padded, so actually we shouldn't
+ // add anything to the upper bound.
+ // CHECK: %[[ZERO0:.*]] = affine.apply #[[$MAP_C0]]()
+ // CHECK: %[[ZERO1:.*]] = affine.apply #[[$MAP_C0]]()
+ // CHECK: %[[PADDED_ARG4:.*]] = tensor.pad %[[ARG4]] nofold low[{{.*}}] high[%[[ZERO0]], %[[ZERO1]]]
+
+ // CHECK: %[[T5:.*]] = linalg.matmul
+ // CHECK-SAME: ins(%[[PADDED_A_SLICE]], %[[PADDED_B_SLICE]] : tensor<128x16xf32>, tensor<16x128xf32>)
+ // CHECK-SAME: outs(%[[PADDED_ARG4]] : tensor<128x128xf32>)
+ %res = linalg.matmul ins(%extracted_slice_2, %extracted_slice_3 : tensor<128x?xf32>, tensor<?x128xf32>) outs(%arg4 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ scf.yield %res : tensor<128x128xf32>
+ }
+ return %10 : tensor<128x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1 = transform.structured.pad %0 {
+ padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
+ padding_dimensions=[0, 1, 2],
+ pack_paddings=[1, 1, 1]
+ }
+}
More information about the Mlir-commits
mailing list