[Mlir-commits] [mlir] 1b2c8f1 - [mlir][linalg] Extract `GeneralizePadOpPattern` into a standalone transformation (#117329)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 26 00:11:19 PST 2024
Author: Andrzej Warzyński
Date: 2024-11-26T08:11:15Z
New Revision: 1b2c8f104f9c6f26500ab608060bbc6b7f40f5e1
URL: https://github.com/llvm/llvm-project/commit/1b2c8f104f9c6f26500ab608060bbc6b7f40f5e1
DIFF: https://github.com/llvm/llvm-project/commit/1b2c8f104f9c6f26500ab608060bbc6b7f40f5e1.diff
LOG: [mlir][linalg] Extract `GeneralizePadOpPattern` into a standalone transformation (#117329)
Currently, `GeneralizePadOpPattern` is grouped under
`populatePadOpVectorizationPatterns`. However, as noted in #111349, this
transformation "decomposes" rather than "vectorizes" `tensor.pad`. As
such, it functions as:
* a vectorization _pre-processing_ transformation, not
* a vectorization transformation itself.
To clarify its purpose, this PR turns `GeneralizePadOpPattern` into a
standalone transformation by:
* introducing a dedicated `populateDecomposePadPatterns` method,
* adding an `apply_patterns.linalg.decompose_pad` Transform Dialect Op,
* removing it from `populatePadOpVectorizationPatterns`.
In addition, to better reflect its role, it is renamed to use "decomposition"
rather than "generalization". This is in line with the recent renaming
of similar patterns, namely those for tensor.pack/tensor.unpack Ops, in #116439.
Added:
mlir/test/Dialect/Linalg/decompose-pad-tensor.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Conversion/TensorToLinalg/TensorToLinalg.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index e3084530bd11b5..dc10f3a1c58ae3 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -52,6 +52,17 @@ def ApplyDecomposeTensorPackUnpackPatternsOp
let assemblyFormat = "attr-dict";
}
+def ApplyDecomposeTensorPadPatternsOp
+ : Op<Transform_Dialect, "apply_patterns.linalg.decompose_pad",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Collect patterns to decompose tensor.pad into e.g. tensor::EmptyOp,
+ linalg::FillOp and tensor::InsertSliceOp.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyFoldUnitExtentDimsViaReshapesPatternsOp : Op<Transform_Dialect,
"apply_patterns.linalg.fold_unit_extent_dims_via_reshapes",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 51967f83fee377..3c160d55a38e75 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1503,8 +1503,8 @@ using OptimizeCopyFn =
/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
/// InsertSliceOp. For now, only constant padding values are supported.
-struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
- GeneralizePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
+struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
+ DecomposePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
: OpRewritePattern<tensor::PadOp>(context, benefit) {}
LogicalResult matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const override;
@@ -1688,6 +1688,10 @@ void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
/// outer dims to be unit.
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns);
+/// Populates patterns to decompose tensor.pad into e.g.
+/// tensor.empty, linalg.fill, tensor.insert_slice.
+void populateDecomposePadPatterns(RewritePatternSet &patterns);
+
/// Populates patterns to transform linalg.conv_2d_xxx operations into
/// linalg.generic (for img2col packing) and linalg.matmul.
/// \see rewriteInIm2Col for more details.
diff --git a/mlir/lib/Conversion/TensorToLinalg/TensorToLinalg.cpp b/mlir/lib/Conversion/TensorToLinalg/TensorToLinalg.cpp
index 5bb79d4bc84e2b..b0ca0ca13d0624 100644
--- a/mlir/lib/Conversion/TensorToLinalg/TensorToLinalg.cpp
+++ b/mlir/lib/Conversion/TensorToLinalg/TensorToLinalg.cpp
@@ -25,5 +25,7 @@ using namespace mlir;
//===----------------------------------------------------------------------===//
void mlir::populateTensorToLinalgPatterns(RewritePatternSet &patterns) {
- patterns.add<mlir::linalg::GeneralizePadOpPattern>(patterns.getContext());
+ // TODO: Add the remaining patterns, e.g. to decompose Pack/Unpack Ops.
+ // Alternatively, delete this file.
+ patterns.add<mlir::linalg::DecomposePadOpPattern>(patterns.getContext());
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ada80deacfdbfe..e08be7d2ebd6ae 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -234,6 +234,11 @@ void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
linalg::populateDecomposePackUnpackPatterns(patterns);
}
+void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ linalg::populateDecomposePadPatterns(patterns);
+}
+
void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
linalg::ControlDropUnitDims options;
@@ -3491,8 +3496,12 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
// Add misc. vectorization patterns (e.g. for tensor.insert_slice)
linalg::populateInsertSliceVectorizationPatterns(patterns);
- if (getVectorizePadding())
+ if (getVectorizePadding()) {
linalg::populatePadOpVectorizationPatterns(patterns);
+ // This creates an alternative path for lowering tensor.pad - by
+ // decomposing it into e.g. linalg.fill.
+ linalg::populateDecomposePadPatterns(patterns);
+ }
vector::populateVectorStepLoweringPatterns(patterns);
TrackingListener listener(state, *this);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index d92543d7264625..c3e176299317ef 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -921,7 +921,7 @@ LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
/// Filling `dest` using FillOp constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
-Value GeneralizePadOpPattern::createFillOrGenerateOp(
+Value DecomposePadOpPattern::createFillOrGenerateOp(
RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
const SmallVector<Value> &dynSizes) const {
auto padValue = padOp.getConstantPaddingValue();
@@ -938,8 +938,8 @@ Value GeneralizePadOpPattern::createFillOrGenerateOp(
}
LogicalResult
-GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
- PatternRewriter &rewriter) const {
+DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
+ PatternRewriter &rewriter) const {
// Given an OpFoldResult, return an index-typed value.
auto getIdxValue = [&](OpFoldResult ofr) {
if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
@@ -1623,3 +1623,7 @@ void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
// TODO: Add and test patterns for tensor.unpack
patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext());
}
+
+void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {
+ patterns.add<DecomposePadOpPattern>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 23b46a2ee55f8d..06bb6c0fb1cac9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2770,12 +2770,6 @@ void mlir::linalg::populateInsertSliceVectorizationPatterns(
void mlir::linalg::populatePadOpVectorizationPatterns(
RewritePatternSet &patterns, PatternBenefit baseBenefit) {
- // TODO: The following pattern implements "decomposition" and
- // optional "vectorization". Seperate "decomposition" into a sepereate
- // pre-processing pattern group.
- patterns.add<GeneralizePadOpPattern>(patterns.getContext(), baseBenefit);
-
- // Try these specialized patterns first before resorting to the generic one.
patterns.add<PadOpVectorizationWithTransferReadPattern,
PadOpVectorizationWithTransferWritePattern,
PadOpVectorizationWithInsertSlicePattern>(
diff --git a/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir b/mlir/test/Dialect/Linalg/decompose-pad-tensor.mlir
similarity index 98%
rename from mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir
rename to mlir/test/Dialect/Linalg/decompose-pad-tensor.mlir
index 2beab31b613d54..184361dfb30dfd 100644
--- a/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-pad-tensor.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-pad-tensor" %s | FileCheck %s
+// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-decompose-pad-tensor" %s | FileCheck %s
// CHECK-LABEL: func @generalize_pad_tensor_static_shape(
// CHECK-SAME: %[[IN:.*]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
diff --git a/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir
index 640de85cc5f12e..41e480648177f5 100644
--- a/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir
@@ -202,6 +202,8 @@ module attributes {transform.with_named_sequence} {
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
+ // TODO: Split into two tests, one for each pattern
+ transform.apply_patterns.linalg.decompose_pad
transform.apply_patterns.linalg.pad_vectorization
} : !transform.op<"func.func">
transform.yield
@@ -236,6 +238,8 @@ module attributes {transform.with_named_sequence} {
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
+ // TODO: Split into two tests, one for each pattern
+ transform.apply_patterns.linalg.decompose_pad
transform.apply_patterns.linalg.pad_vectorization
} : !transform.op<"func.func">
transform.yield
@@ -270,6 +274,8 @@ module attributes {transform.with_named_sequence} {
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
+ // TODO: Split into two tests, one for each pattern
+ transform.apply_patterns.linalg.decompose_pad
transform.apply_patterns.linalg.pad_vectorization
} : !transform.op<"func.func">
transform.yield
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index c65e68eaf31f09..25aec75c3c14ad 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -70,8 +70,8 @@ struct TestLinalgTransforms
llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
"in vector.contract form"),
llvm::cl::init(false)};
- Option<bool> testGeneralizePadTensor{
- *this, "test-generalize-pad-tensor",
+ Option<bool> testDecomposePadTensor{
+ *this, "test-decompose-pad-tensor",
llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
llvm::cl::init(false)};
Option<bool> testDecomposeTensorPackOp{
@@ -166,9 +166,9 @@ static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
-static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) {
+static void applyDecomposePadPatterns(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
- patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
+ patterns.add<DecomposePadOpPattern>(funcOp.getContext());
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
@@ -235,8 +235,8 @@ void TestLinalgTransforms::runOnOperation() {
return applyVectorTransferForwardingPatterns(getOperation());
if (testGenericToVectorPattern)
return applyLinalgToVectorPatterns(getOperation());
- if (testGeneralizePadTensor)
- return applyGeneralizePadTensorPatterns(getOperation());
+ if (testDecomposePadTensor)
+ return applyDecomposePadPatterns(getOperation());
if (testDecomposeTensorPackOp)
return applyDecomposeTensorPackPatterns(getOperation());
if (testDecomposeTensorUnPackOp)
More information about the Mlir-commits
mailing list