[Mlir-commits] [mlir] [mlir][vector] Add a new TD Op for patterns leveraging ShapeCastOp (PR #110525)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Sep 30 08:53:52 PDT 2024


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/110525

Adds a new Transform Dialect Op that collects patterns for dropping unit
dims from various Ops:
  * `transform.apply_patterns.vector.drop_unit_dims_with_shape_cast`.

It excludes patterns for vector.transfer Ops - these are collected
under:
  * `apply_patterns.vector.rank_reducing_subview_patterns`,

and use ShapeCastOp _and_ SubviewOp to reduce the rank (and to eliminate
unit dims).

This new TD Op allows us to test the "ShapeCast folder" pattern in
isolation. I've extracted the only test that I could find for that
folder from "vector-transforms.mlir" and moved it to a dedicated file:
"shape-cast-folder.mlir". I also added a test case with scalable
vectors.


>From f17421f36555663e0a11a269ccf8fa33ae65a8c7 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 29 Sep 2024 15:55:22 +0100
Subject: [PATCH] [mlir][vector] Add a new TD Op for patterns leveraging
 ShapeCastOp

Adds a new Transform Dialect Op that collects patterns for dropping unit
dims from various Ops:
  * `transform.apply_patterns.vector.drop_unit_dims_with_shape_cast`.

It excludes patterns for vector.transfer Ops - these are collected
under:
  * `apply_patterns.vector.rank_reducing_subview_patterns`,

and use ShapeCastOp _and_ SubviewOp to reduce the rank (and to eliminate
unit dims).

This new TD Op allows us to test the "ShapeCast folder" pattern in
isolation. I've extracted the only test that I could find for that
folder from "vector-transforms.mlir" and moved it to a dedicated file:
"shape-cast-folder.mlir". I also added a test case with scalable
vectors.
---
 .../Vector/TransformOps/VectorTransformOps.td | 16 ++++++++
 .../TransformOps/VectorTransformOps.cpp       |  5 +++
 .../Vector/Transforms/VectorTransforms.cpp    |  7 +++-
 .../Dialect/Vector/shape-cast-folder.mlir     | 38 +++++++++++++++++++
 .../Dialect/Vector/vector-transforms.mlir     |  9 -----
 5 files changed, 64 insertions(+), 11 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/shape-cast-folder.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index aad2eab83dbd38..c973eca0132a92 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -68,6 +68,22 @@ def ApplyRankReducingSubviewPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyDropUnitDimWithShapeCastPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.drop_unit_dims_with_shape_cast",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Apply vector patterns to drop unit dims using vector.shape_cast Ops:
+      - DropUnitDimFromElementwiseOps
+      - DropUnitDimsFromScfForOp
+      - DropUnitDimsFromTransposeOp
+      - ShapeCastOpFolder (folds pairs of cancelling vector.shape_cast Ops)
+
+    Excludes vector.transfer patterns (see rank_reducing_subview_patterns).
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyTransferPermutationPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.transfer_permutation_patterns",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 3ae70ace3934cd..241e83e234d621 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -85,6 +85,11 @@ void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
 }
 
+void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateDropUnitDimWithShapeCastPatterns(patterns);
+}
+
 void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::populateVectorBitCastLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 8fcef54f12edfa..09b6bfd0a15e3b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2056,8 +2056,11 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
 
 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
-               ShapeCastOpFolder, DropUnitDimsFromScfForOp>(
+  // TODO: Consider either adding DropInnerMostUnitDimsTransferRead and
+  // DropInnerMostUnitDimsTransferWrite here, or renaming this to disambiguate
+  // it from populateVectorTransferCollapseInnerMostContiguousDimsPatterns.
+  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
+               DropUnitDimsFromTransposeOp, ShapeCastOpFolder>(
       patterns.getContext(), benefit);
 }
 
diff --git a/mlir/test/Dialect/Vector/shape-cast-folder.mlir b/mlir/test/Dialect/Vector/shape-cast-folder.mlir
new file mode 100644
index 00000000000000..9550c5c4ae0561
--- /dev/null
+++ b/mlir/test/Dialect/Vector/shape-cast-folder.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+///----------------------------------------------------------------------------------------
+/// [Pattern: ShapeCastOpFolder]
+///----------------------------------------------------------------------------------------
+
+// CHECK-LABEL: func @fixed_width
+//  CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
+//   CHECK-NOT: vector.shape_cast
+//       CHECK: return %[[A0]] : vector<2x4xf32>
+func.func @fixed_width(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
+  %0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
+  %1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>
+  return %1 : vector<2x4xf32>
+}
+
+// CHECK-LABEL: func @scalable
+//  CHECK-SAME: %[[A0:.*0]]: vector<2x[4]xf32>
+//   CHECK-NOT: vector.shape_cast
+//       CHECK: return %[[A0]] : vector<2x[4]xf32>
+func.func @scalable(%arg0 : vector<2x[4]xf32>) -> vector<2x[4]xf32> {
+  %0 = vector.shape_cast %arg0 : vector<2x[4]xf32> to vector<[8]xf32>
+  %1 = vector.shape_cast %0 : vector<[8]xf32> to vector<2x[4]xf32>
+  return %1 : vector<2x[4]xf32>
+}
+
+// ============================================================================
+//  TD sequence
+// ============================================================================
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.drop_unit_dims_with_shape_cast
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index eda6a5cc40d999..89e8ca1d93109c 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -184,15 +184,6 @@ func.func @vector_transfers(%arg0: index, %arg1: index) {
   return
 }
 
-// CHECK-LABEL: func @cancelling_shape_cast_ops
-//  CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
-//       CHECK: return %[[A0]] : vector<2x4xf32>
-func.func @cancelling_shape_cast_ops(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
-  %0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
-  %1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>
-  return %1 : vector<2x4xf32>
-}
-
 // CHECK-LABEL: func @elementwise_unroll
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32>, %[[ARG1:.*]]: memref<4x4xf32>)
 //       CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index



More information about the Mlir-commits mailing list