[Mlir-commits] [mlir] [mlir] Add apply_patterns.linalg.generalize_pack_unpack TD Op (PR #116373)

Andrzej Warzyński llvmlistbot at llvm.org
Sat Nov 16 03:18:52 PST 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/116373

>From 138ccd94e1d2eda74df411b29600d9b612981b67 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 12 Nov 2024 20:33:29 +0000
Subject: [PATCH] [mlir] Add apply_patterns.linalg.generalize_pack_unpack
 Transform Dialect Op

This PR introduces populateGeneralizePatterns, which collects the
following patterns:

  * `GeneralizeOuterUnitDimsPackOpPattern`,
  * `GeneralizeOuterUnitDimsUnPackOpPattern` (currently a TODO).

These patterns are wrapped in a new Transform Dialect Op:
`apply_patterns.linalg.generalize_pack_unpack`. This Op facilitates
creating more involved end-to-end compilation pipelines for
`tensor.pack` and `tensor.unpack` operations. It will be required in an
upcoming PR building on top of #115698.

No new tests are added in this PR. Instead, existing tests from:

  * "generalize-tensor-pack.mlir"

are reused. To achieve this:

  * I've updated the test to use
    `transform.apply_patterns.linalg.generalize_pack_unpack` instead of
    the flag
    `--test-linalg-transform-patterns="test-generalize-tensor-pack"`,
    avoiding artificial tests solely for the TD Op.
  * The TD sequence is saved to a new file, "generalize_pack.mlir", and
    pre-loaded using the option:
    `--transform-preload-library='transform-library-paths=%p/td/generalize_pack.mlir'`
    This avoids duplicating the sequence for every "split" in the input
    file.
  * Added lit.local.cfg to exclude the "test/Dialect/Linalg/td"
    directory from test discovery, ensuring "generalize_pack.mlir" is
    not treated as a test file.
---
 .../Linalg/TransformOps/LinalgTransformOps.td        | 12 ++++++++++++
 .../mlir/Dialect/Linalg/Transforms/Transforms.h      |  9 +++++++--
 .../Linalg/TransformOps/LinalgTransformOps.cpp       |  5 +++++
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp    |  5 +++++
 mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir |  3 +--
 mlir/test/Dialect/Linalg/lit.local.cfg               |  2 ++
 mlir/test/Dialect/Linalg/td/generalize_pack.mlir     | 12 ++++++++++++
 7 files changed, 44 insertions(+), 4 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/lit.local.cfg
 create mode 100644 mlir/test/Dialect/Linalg/td/generalize_pack.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f256af2f6b12b8..42057d8d0c9105 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -41,6 +41,18 @@ def ApplyEraseUnnecessaryInputsPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyGeneralizeTensorPackUnpackPatternsOp
+    : Op<Transform_Dialect, "apply_patterns.linalg.generalize_pack_unpack",
+         [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect patterns to generalize tensor.pack and tensor.unpack (i.e. to
+    decompose them into e.g. tensor::PadOp, linalg::TransposeOp etc.). Requires
+    all outer dims to be unit. NOTE: only tensor.pack is supported currently.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyFoldUnitExtentDimsViaReshapesPatternsOp : Op<Transform_Dialect,
     "apply_patterns.linalg.fold_unit_extent_dims_via_reshapes",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 89e9a3b70d2ab3..0b55a76f884331 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1516,8 +1516,8 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
 };
 
 /// Rewrites a tensor::PackOp into a sequence of:
-///   * tensor::PadOp + linalg::TransposeOp +
-///     tensor::EmptyOp + tensor::InsertSliceOp ops.
+///   * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
+///     tensor::InsertSliceOp ops.
 ///
 /// Required that all the outer dims of the input tensor::PackOp are 1.
 ///
@@ -1683,6 +1683,11 @@ void populateLinalgGenericOpsSpecializationPatterns(
 void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                           PatternBenefit benefit = 1);
 
+/// Populates patterns to decompose tensor.pack and tensor.unpack Ops into e.g.
+/// tensor.pad, linalg.transpose, tensor.{insert|extract}_slice. Requires all
+/// outer dims to be unit. NOTE: only tensor.pack is supported at the moment.
+void populateGeneralizePatterns(RewritePatternSet &patterns);
+
 /// Populates patterns to transform linalg.conv_2d_xxx operations into
 /// linalg.generic (for img2col packing) and linalg.matmul.
 /// \see rewriteInIm2Col for more details.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 1956fc634ef395..a00c609779c3a7 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -229,6 +229,11 @@ void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
   linalg::populateEraseUnnecessaryInputsPatterns(patterns);
 }
 
+void transform::ApplyGeneralizeTensorPackUnpackPatternsOp::populatePatterns(
+    RewritePatternSet &patternSet) {
+  mlir::linalg::populateGeneralizePatterns(patternSet);
+}
+
 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   linalg::ControlDropUnitDims options;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index ed9ebca4f306a4..c9eac663675599 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1618,3 +1618,8 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
       DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
       patterns.getContext(), benefit);
 }
+
+void linalg::populateGeneralizePatterns(RewritePatternSet &patterns) {
+  // TODO: Add GeneralizeOuterUnitDimsUnPackOpPattern (tensor.unpack) + tests.
+  patterns.add<GeneralizeOuterUnitDimsPackOpPattern>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index f4b1d9a55f0914..0a664d4347fa9a 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-tensor-pack"  %s | FileCheck %s
-
+// RUN: mlir-opt -split-input-file --transform-preload-library='transform-library-paths=%p/td/generalize_pack.mlir' --transform-interpreter %s | FileCheck %s
 
 func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<?x?xi32>, %arg1: tensor<1x1x?x1xi32>) -> tensor<1x1x?x1xi32> {
   %c8 = arith.constant 8 : index
diff --git a/mlir/test/Dialect/Linalg/lit.local.cfg b/mlir/test/Dialect/Linalg/lit.local.cfg
new file mode 100644
index 00000000000000..62743008a3e3a6
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/lit.local.cfg
@@ -0,0 +1,2 @@
+# Skip the "td" directory: it holds TD sequences preloaded by tests, not tests.
+config.excludes = list(config.excludes) + ["td"]
diff --git a/mlir/test/Dialect/Linalg/td/generalize_pack.mlir b/mlir/test/Dialect/Linalg/td/generalize_pack.mlir
new file mode 100644
index 00000000000000..62e5b779ff361a
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/td/generalize_pack.mlir
@@ -0,0 +1,12 @@
+module @transforms attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["tensor.pack"]} in %module : (!transform.any_op) -> !transform.any_op
+    // Apply the patterns to the closest isolated-from-above ancestor of each match.
+    %parent = transform.get_parent_op %pack {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %parent {
+      transform.apply_patterns.linalg.generalize_pack_unpack
+    } : !transform.any_op
+
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list