[Mlir-commits] [mlir] [mlir][vector] Add a new TD Op for patterns leveraging ShapeCastOp (PR #110525)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 30 08:54:28 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
Adds a new Transform Dialect Op that collects patterns for dropping unit
dims from various Ops:
* `transform.apply_patterns.vector.drop_unit_dims_with_shape_cast`.
It excludes patterns for vector.transfer Ops - these are collected
under:
* `apply_patterns.vector.rank_reducing_subview_patterns`,
and use ShapeCastOp _and_ SubviewOp to reduce the rank (and to eliminate
unit dims).
This new TD Op allows us to test the "ShapeCast folder" pattern in
isolation. I've extracted the only test that I could find for that
folder from "vector-transforms.mlir" and moved it to a dedicated file:
"shape-cast-folder.mlir". I also added a test case with scalable
vectors.
---
Full diff: https://github.com/llvm/llvm-project/pull/110525.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td (+16)
- (modified) mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (+5)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+5-2)
- (added) mlir/test/Dialect/Vector/shape-cast-folder.mlir (+38)
- (modified) mlir/test/Dialect/Vector/vector-transforms.mlir (-9)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index aad2eab83dbd38..c973eca0132a92 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -68,6 +68,22 @@ def ApplyRankReducingSubviewPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyDropUnitDimWithShapeCastPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.drop_unit_dims_with_shape_cast",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Apply vector patterns to fold unit dims with vector.shape_cast Ops:
+ - DropUnitDimFromElementwiseOps
+ - DropUnitDimsFromScfForOp
+ - DropUnitDimsFromTransposeOp
+
+ Excludes patterns for vector.transfer Ops. This is complemented by
+ shape_cast folding patterns.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyTransferPermutationPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.transfer_permutation_patterns",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 3ae70ace3934cd..241e83e234d621 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -85,6 +85,11 @@ void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
}
+void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateDropUnitDimWithShapeCastPatterns(patterns);
+}
+
void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorBitCastLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 8fcef54f12edfa..09b6bfd0a15e3b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2056,8 +2056,11 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
- ShapeCastOpFolder, DropUnitDimsFromScfForOp>(
+ // TODO: Consider either including DropInnerMostUnitDimsTransferRead,
+ // DropInnerMostUnitDimsTransferWrite or better naming to distinguish this
+ // and populateVectorTransferCollapseInnerMostContiguousDimsPatterns.
+ patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
+ DropUnitDimsFromTransposeOp, ShapeCastOpFolder>(
patterns.getContext(), benefit);
}
diff --git a/mlir/test/Dialect/Vector/shape-cast-folder.mlir b/mlir/test/Dialect/Vector/shape-cast-folder.mlir
new file mode 100644
index 00000000000000..9550c5c4ae0561
--- /dev/null
+++ b/mlir/test/Dialect/Vector/shape-cast-folder.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+
+///----------------------------------------------------------------------------------------
+/// [Pattern: ShapeCastOpFolder]
+///----------------------------------------------------------------------------------------
+
+// CHECK-LABEL: func @fixed_width
+// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
+// CHECK-NOT: vector.shape_cast
+// CHECK: return %[[A0]] : vector<2x4xf32>
+func.func @fixed_width(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
+ %1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>
+ return %1 : vector<2x4xf32>
+}
+
+// CHECK-LABEL: func @scalable
+// CHECK-SAME: %[[A0:.*0]]: vector<2x[4]xf32>
+// CHECK-NOT: vector.shape_cast
+// CHECK: return %[[A0]] : vector<2x[4]xf32>
+func.func @scalable(%arg0 : vector<2x[4]xf32>) -> vector<2x[4]xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2x[4]xf32> to vector<[8]xf32>
+ %1 = vector.shape_cast %0 : vector<[8]xf32> to vector<2x[4]xf32>
+ return %1 : vector<2x[4]xf32>
+}
+
+// ============================================================================
+// TD sequence
+// ============================================================================
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.drop_unit_dims_with_shape_cast
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index eda6a5cc40d999..89e8ca1d93109c 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -184,15 +184,6 @@ func.func @vector_transfers(%arg0: index, %arg1: index) {
return
}
-// CHECK-LABEL: func @cancelling_shape_cast_ops
-// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
-// CHECK: return %[[A0]] : vector<2x4xf32>
-func.func @cancelling_shape_cast_ops(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
- %0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
- %1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>
- return %1 : vector<2x4xf32>
-}
-
// CHECK-LABEL: func @elementwise_unroll
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32>, %[[ARG1:.*]]: memref<4x4xf32>)
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
``````````
</details>
https://github.com/llvm/llvm-project/pull/110525
More information about the Mlir-commits
mailing list