[Mlir-commits] [mlir] [mlir] Add apply_patterns.linalg.generalize_pack_unpack TD Op (PR #116373)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Nov 15 12:19:14 PST 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/116373

>From 4c8037bbbe082dbe45aa6b75a5fa12f0bf0e59a7 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 12 Nov 2024 20:33:29 +0000
Subject: [PATCH 1/2] [mlir] Add apply_patterns.linalg.generalize_pack_unpack
 TD Op

This PR simply creates `populateGeneralizePatterns` to collect the
following patterns:
  * `GeneralizeOuterUnitDimsPackOpPattern`,
  * `GeneralizeOuterUnitDimsUnPackOpPattern` (left as a TODO),

and wraps the newly created `populate*` method into a new Transform
Dialect Op: `apply_patterns.linalg.generalize_pack_unpack`. This allows
creating more involved e2e compilation pipelines for `tensor.pack` and
`tensor.unpack` Ops. I will require this Op in a forthcoming
PR that will build on top of #115698.

No new tests are added. Instead, I've re-used existing tests:
  * "generalize-tensor-pack.mlir".
---
 .../Linalg/TransformOps/LinalgTransformOps.td        | 12 ++++++++++++
 .../mlir/Dialect/Linalg/Transforms/Transforms.h      |  9 +++++++--
 .../Linalg/TransformOps/LinalgTransformOps.cpp       |  5 +++++
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp    |  5 +++++
 mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir |  3 +--
 mlir/test/Dialect/Linalg/lit.local.cfg               |  2 ++
 mlir/test/Dialect/Linalg/td/generalize_pack.mlir     | 12 ++++++++++++
 7 files changed, 44 insertions(+), 4 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/lit.local.cfg
 create mode 100644 mlir/test/Dialect/Linalg/td/generalize_pack.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 25a98a16960f37..c77013f35f6d26 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -41,6 +41,18 @@ def ApplyEraseUnnecessaryInputsPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyGeneralizeTensorPackUnpackPatternsOp
+    : Op<Transform_Dialect, "apply_patterns.linalg.generalize_pack_unpack",
+         [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect patterns to generalize tensor.pack and tensor.unpack (i.e. to
+    decompose it into e.g. tensor::PadOp, linalg::transposeOp etc). Requires
+    all outer dims to be unit.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyFoldUnitExtentDimsViaReshapesPatternsOp : Op<Transform_Dialect,
     "apply_patterns.linalg.fold_unit_extent_dims_via_reshapes",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 89e9a3b70d2ab3..0b55a76f884331 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1516,8 +1516,8 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
 };
 
 /// Rewrites a tensor::PackOp into a sequence of:
-///   * tensor::PadOp + linalg::TransposeOp +
-///     tensor::EmptyOp + tensor::InsertSliceOp ops.
+///   * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
+///     tensor::InsertSliceOp ops.
 ///
 /// Required that all the outer dims of the input tensor::PackOp are 1.
 ///
@@ -1683,6 +1683,11 @@ void populateLinalgGenericOpsSpecializationPatterns(
 void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                           PatternBenefit benefit = 1);
 
+/// Populates patterns to decompose tensor.pack and tensor.unpack Ops into e.g.
+/// tensor.pad, linalg.transpose, tensor.{insert|extract}_slice. Require all
+/// outer dims to be unit.
+void populateGeneralizePatterns(RewritePatternSet &patterns);
+
 /// Populates patterns to transform linalg.conv_2d_xxx operations into
 /// linalg.generic (for img2col packing) and linalg.matmul.
 /// \see rewriteInIm2Col for more details.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 88e116bce7f595..88fd54bef8e8bb 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -229,6 +229,11 @@ void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
   linalg::populateEraseUnnecessaryInputsPatterns(patterns);
 }
 
+void transform::ApplyGeneralizeTensorPackUnpackPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  linalg::populateGeneralizePatterns(patterns);
+}
+
 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   linalg::ControlDropUnitDims options;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index ed9ebca4f306a4..c9eac663675599 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1618,3 +1618,8 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
       DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
       patterns.getContext(), benefit);
 }
+
+void linalg::populateGeneralizePatterns(RewritePatternSet &patterns) {
+  // TODO: Add and test patterns for tensor.unpack
+  patterns.add<GeneralizeOuterUnitDimsPackOpPattern>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index f4b1d9a55f0914..c0c6f82a38179b 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-tensor-pack"  %s | FileCheck %s
-
+// RUN: mlir-opt  --transform-preload-library='transform-library-paths=%p/td/generalize_pack.mlir' -split-input-file  --transform-interpreter --test-transform-dialect-erase-schedule %s | FileCheck %s
 
 func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<?x?xi32>, %arg1: tensor<1x1x?x1xi32>) -> tensor<1x1x?x1xi32> {
   %c8 = arith.constant 8 : index
diff --git a/mlir/test/Dialect/Linalg/lit.local.cfg b/mlir/test/Dialect/Linalg/lit.local.cfg
new file mode 100644
index 00000000000000..62743008a3e3a6
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/lit.local.cfg
@@ -0,0 +1,2 @@
+# Skip the directory with input TD sequences
+config.excludes = ["td"]
diff --git a/mlir/test/Dialect/Linalg/td/generalize_pack.mlir b/mlir/test/Dialect/Linalg/td/generalize_pack.mlir
new file mode 100644
index 00000000000000..62e5b779ff361a
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/td/generalize_pack.mlir
@@ -0,0 +1,12 @@
+module @transforms attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["tensor.pack"]} in %module : (!transform.any_op) -> !transform.any_op
+
+    %1 = transform.get_parent_op %pack {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %1 {
+      transform.apply_patterns.linalg.generalize_pack_unpack
+    } : !transform.any_op
+
+    transform.yield
+  }
+}

>From 861953f38778a2eeebc1f055f4e04e2d068c4f9d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 15 Nov 2024 20:18:47 +0000
Subject: [PATCH 2/2] fixup! [mlir] Add
 apply_patterns.linalg.generalize_pack_unpack TD Op

Generalize -> Decompose
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  9 ++++---
 .../Dialect/Linalg/Transforms/Transforms.h    |  6 ++---
 .../TransformOps/LinalgTransformOps.cpp       |  4 ++--
 .../Dialect/Linalg/Transforms/Transforms.cpp  |  8 +++----
 .../Linalg/generalize-tensor-pack-tile.mlir   |  2 +-
 .../Linalg/generalize-tensor-pack.mlir        |  2 +-
 .../Linalg/generalize-tensor-unpack-tile.mlir |  2 +-
 .../Linalg/generalize-tensor-unpack.mlir      |  2 +-
 ...neralize_pack.mlir => decompose_pack.mlir} |  2 +-
 .../Linalg/CPU/pack-dynamic-inner-tile.mlir   |  2 +-
 .../Dialect/Linalg/TestLinalgTransforms.cpp   | 24 +++++++++----------
 11 files changed, 31 insertions(+), 32 deletions(-)
 rename mlir/test/Dialect/Linalg/td/{generalize_pack.mlir => decompose_pack.mlir} (88%)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index c77013f35f6d26..532d33962a7b35 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -41,13 +41,12 @@ def ApplyEraseUnnecessaryInputsPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
-def ApplyGeneralizeTensorPackUnpackPatternsOp
-    : Op<Transform_Dialect, "apply_patterns.linalg.generalize_pack_unpack",
+def ApplyDecomposeTensorPackUnpackPatternsOp
+    : Op<Transform_Dialect, "apply_patterns.linalg.decompose_pack_unpack",
          [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Collect patterns to generalize tensor.pack and tensor.unpack (i.e. to
-    decompose it into e.g. tensor::PadOp, linalg::transposeOp etc). Requires
-    all outer dims to be unit.
+    Collect patterns to decompose tensor.pack and tensor.unpack into e.g.
+    tensor::PadOp, linalg::transposeOp Ops. Requires all outer dims to be unit.
   }];
 
   let assemblyFormat = "attr-dict";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 0b55a76f884331..51967f83fee377 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1548,7 +1548,7 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
 ///     into %arg1[0, 0, 0, 0] [1, 1, 2, %tile_dim_1] [1, 1, 1, 1]
 ///     : tensor<2x?xf32> into tensor<1x1x2x?xf32>
 /// ```
-struct GeneralizeOuterUnitDimsPackOpPattern
+struct DecomposeOuterUnitDimsPackOpPattern
     : public OpRewritePattern<tensor::PackOp> {
   using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(tensor::PackOp packOp,
@@ -1558,7 +1558,7 @@ struct GeneralizeOuterUnitDimsPackOpPattern
 /// Rewrites a tensor::UnPackOp into a sequence of rank-reduced extract_slice op
 /// + transpose op + insert_slice op, where the tensor::UnPackOp has outer dims
 /// being all 1s.
-struct GeneralizeOuterUnitDimsUnPackOpPattern
+struct DecomposeOuterUnitDimsUnPackOpPattern
     : public OpRewritePattern<tensor::UnPackOp> {
   using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp,
@@ -1686,7 +1686,7 @@ void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
 /// Populates patterns to decompose tensor.pack and tensor.unpack Ops into e.g.
 /// tensor.pad, linalg.transpose, tensor.{insert|extract}_slice. Require all
 /// outer dims to be unit.
-void populateGeneralizePatterns(RewritePatternSet &patterns);
+void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns);
 
 /// Populates patterns to transform linalg.conv_2d_xxx operations into
 /// linalg.generic (for img2col packing) and linalg.matmul.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 88fd54bef8e8bb..75a696b5419e81 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -229,9 +229,9 @@ void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
   linalg::populateEraseUnnecessaryInputsPatterns(patterns);
 }
 
-void transform::ApplyGeneralizeTensorPackUnpackPatternsOp::populatePatterns(
+void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  linalg::populateGeneralizePatterns(patterns);
+  linalg::populateDecomposePackUnpackPatterns(patterns);
 }
 
 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index c9eac663675599..d92543d7264625 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1138,7 +1138,7 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
   return perm;
 }
 
-LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
+LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
   // TODO: support the case that outer dimensions are not all 1s. A
   // tensor.expand_shape will be generated in this case.
@@ -1239,7 +1239,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   return success();
 }
 
-LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
+LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
     tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const {
   int64_t srcRank = unpackOp.getSourceRank();
   int64_t destRank = unpackOp.getDestRank();
@@ -1619,7 +1619,7 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
       patterns.getContext(), benefit);
 }
 
-void linalg::populateGeneralizePatterns(RewritePatternSet &patterns) {
+void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
   // TODO: Add and test patterns for tensor.unpack
-  patterns.add<GeneralizeOuterUnitDimsPackOpPattern>(patterns.getContext());
+  patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
index 1fae311467bcf4..e855aeaf80b1db 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file --transform-interpreter --canonicalize --test-linalg-transform-patterns="test-generalize-tensor-pack"  %s | FileCheck %s
+// RUN: mlir-opt -split-input-file --transform-interpreter --canonicalize --test-linalg-transform-patterns="test-decompose-tensor-pack"  %s | FileCheck %s
 
 func.func @KCRS_to_KCRSsr(%arg0: tensor<1x1x128x64xf32>, %arg1: tensor<1x1x4x8x8x32xf32>) -> tensor<1x1x4x8x8x32xf32> {
   %0 = tensor.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x128x64xf32> -> tensor<1x1x4x8x8x32xf32>
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index c0c6f82a38179b..efb14f632a53c2 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt  --transform-preload-library='transform-library-paths=%p/td/generalize_pack.mlir' -split-input-file  --transform-interpreter --test-transform-dialect-erase-schedule %s | FileCheck %s
+// RUN: mlir-opt  --transform-preload-library='transform-library-paths=%p/td/decompose_pack.mlir' -split-input-file  --transform-interpreter --test-transform-dialect-erase-schedule %s | FileCheck %s
 
 func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<?x?xi32>, %arg1: tensor<1x1x?x1xi32>) -> tensor<1x1x?x1xi32> {
   %c8 = arith.constant 8 : index
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
index c15859d898ec14..6d9709caf70937 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file --transform-interpreter --canonicalize --test-linalg-transform-patterns="test-generalize-tensor-unpack"  %s | FileCheck %s
+// RUN: mlir-opt -split-input-file --transform-interpreter --canonicalize --test-linalg-transform-patterns="test-decompose-tensor-unpack"  %s | FileCheck %s
 
 func.func @KCRSsr_to_KCRS(%arg0: tensor<1x1x4x8x8x32xf32>, %arg1: tensor<1x1x128x64xf32>) -> tensor<1x1x128x64xf32> {
   %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x4x8x8x32xf32> -> tensor<1x1x128x64xf32>
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
index 153ce68b8f086c..8b15873473a972 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-tensor-unpack"  %s | FileCheck %s
+// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-decompose-tensor-unpack"  %s | FileCheck %s
 
 func.func @simple_KCRSsr_to_KCRS(%arg0: tensor<1x1x1x1x8x32xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32> {
   %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x1x1x8x32xf32> -> tensor<1x1x32x8xf32>
diff --git a/mlir/test/Dialect/Linalg/td/generalize_pack.mlir b/mlir/test/Dialect/Linalg/td/decompose_pack.mlir
similarity index 88%
rename from mlir/test/Dialect/Linalg/td/generalize_pack.mlir
rename to mlir/test/Dialect/Linalg/td/decompose_pack.mlir
index 62e5b779ff361a..9c8d481c775798 100644
--- a/mlir/test/Dialect/Linalg/td/generalize_pack.mlir
+++ b/mlir/test/Dialect/Linalg/td/decompose_pack.mlir
@@ -4,7 +4,7 @@ module @transforms attributes { transform.with_named_sequence } {
 
     %1 = transform.get_parent_op %pack {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %1 {
-      transform.apply_patterns.linalg.generalize_pack_unpack
+      transform.apply_patterns.linalg.decompose_pack_unpack
     } : !transform.any_op
 
     transform.yield
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/pack-dynamic-inner-tile.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/pack-dynamic-inner-tile.mlir
index bec1b9a4e9d82f..0428ada86041da 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/pack-dynamic-inner-tile.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/pack-dynamic-inner-tile.mlir
@@ -1,6 +1,6 @@
 // DEFINE: %{compile} =  mlir-opt %s \
 // DEFINE: -transform-interpreter -test-transform-dialect-erase-schedule |\
-// DEFINE:  mlir-opt --test-linalg-transform-patterns="test-generalize-tensor-pack"\
+// DEFINE:  mlir-opt --test-linalg-transform-patterns="test-decompose-tensor-pack"\
 // DEFINE:    --test-transform-dialect-erase-schedule \
 // DEFINE:    -one-shot-bufferize="bufferize-function-boundaries" \
 // DEFINE:    -buffer-deallocation-pipeline="private-function-dynamic-ownership" \
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 5899f56da73459..c65e68eaf31f09 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -74,13 +74,13 @@ struct TestLinalgTransforms
       *this, "test-generalize-pad-tensor",
       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
       llvm::cl::init(false)};
-  Option<bool> testGeneralizeTensorPackOp{
-      *this, "test-generalize-tensor-pack",
+  Option<bool> testDecomposeTensorPackOp{
+      *this, "test-decompose-tensor-pack",
       llvm::cl::desc("Test transform that generalizes pack ops into a sequence "
                      "of tensor and Linalg ops"),
       llvm::cl::init(false)};
-  Option<bool> testGeneralizeTensorUnPackOp{
-      *this, "test-generalize-tensor-unpack",
+  Option<bool> testDecomposeTensorUnPackOp{
+      *this, "test-decompose-tensor-unpack",
       llvm::cl::desc(
           "Test transform that generalizes unpack ops into a sequence "
           "of tensor and Linalg ops"),
@@ -172,15 +172,15 @@ static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
-static void applyGeneralizeTensorPackPatterns(func::FuncOp funcOp) {
+static void applyDecomposeTensorPackPatterns(func::FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
-  patterns.add<GeneralizeOuterUnitDimsPackOpPattern>(funcOp.getContext());
+  patterns.add<DecomposeOuterUnitDimsPackOpPattern>(funcOp.getContext());
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
-static void applyGeneralizeTensorUnPackPatterns(func::FuncOp funcOp) {
+static void applyDecomposeTensorUnPackPatterns(func::FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
-  patterns.add<GeneralizeOuterUnitDimsUnPackOpPattern>(funcOp.getContext());
+  patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(funcOp.getContext());
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
@@ -237,10 +237,10 @@ void TestLinalgTransforms::runOnOperation() {
     return applyLinalgToVectorPatterns(getOperation());
   if (testGeneralizePadTensor)
     return applyGeneralizePadTensorPatterns(getOperation());
-  if (testGeneralizeTensorPackOp)
-    return applyGeneralizeTensorPackPatterns(getOperation());
-  if (testGeneralizeTensorUnPackOp)
-    return applyGeneralizeTensorUnPackPatterns(getOperation());
+  if (testDecomposeTensorPackOp)
+    return applyDecomposeTensorPackPatterns(getOperation());
+  if (testDecomposeTensorUnPackOp)
+    return applyDecomposeTensorUnPackPatterns(getOperation());
   if (testSwapSubTensorPadTensor)
     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
   if (testBubbleUpExtractSliceOpPattern)



More information about the Mlir-commits mailing list