[Mlir-commits] [mlir] 1f5e826 - [mlir][vector] Add a new TD Op for patterns leveraging ShapeCastOp (#110525)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 1 02:08:47 PDT 2024
Author: Andrzej Warzyński
Date: 2024-10-01T10:08:43+01:00
New Revision: 1f5e8263b920f591c517a5dc562cccad39dd6ec7
URL: https://github.com/llvm/llvm-project/commit/1f5e8263b920f591c517a5dc562cccad39dd6ec7
DIFF: https://github.com/llvm/llvm-project/commit/1f5e8263b920f591c517a5dc562cccad39dd6ec7.diff
LOG: [mlir][vector] Add a new TD Op for patterns leveraging ShapeCastOp (#110525)
Adds a new Transform Dialect Op that collects patterns for dropping unit
dims from various Ops:
* `transform.apply_patterns.vector.drop_unit_dims_with_shape_cast`.
It excludes patterns for vector.transfer Ops - these are collected
under:
* `apply_patterns.vector.rank_reducing_subview_patterns`,
and use ShapeCastOp _and_ SubviewOp to reduce the rank (and to eliminate
unit dims).
This new TD Op allows us to test the "ShapeCast folder" pattern in
isolation. I've extracted the only test that I could find for that
folder from "vector-transforms.mlir" and moved it to a dedicated file:
"shape-cast-folder.mlir". I also added a test case with scalable
vectors.
The changes in VectorTransforms.cpp are not strictly required (they add a
comment with a TODO and order the patterns alphabetically). I am including
them here to avoid a separate PR.
Added:
mlir/test/Dialect/Vector/shape-cast-folder.mlir
Modified:
mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index aad2eab83dbd38..c973eca0132a92 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -68,6 +68,22 @@ def ApplyRankReducingSubviewPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+// Transform-dialect entry point for vector::populateDropUnitDimWithShapeCastPatterns
+// (wired up in ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns).
+// Patterns for vector.transfer Ops are intentionally not included here.
+def ApplyDropUnitDimWithShapeCastPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.drop_unit_dims_with_shape_cast",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Apply vector patterns to fold unit dims with vector.shape_cast Ops:
+ - DropUnitDimFromElementwiseOps
+ - DropUnitDimsFromScfForOp
+ - DropUnitDimsFromTransposeOp
+
+ Excludes patterns for vector.transfer Ops. This is complemented by
+ shape_cast folding patterns.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyTransferPermutationPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.transfer_permutation_patterns",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 3ae70ace3934cd..241e83e234d621 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -85,6 +85,11 @@ void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
}
+/// Adds every pattern registered by
+/// vector::populateDropUnitDimWithShapeCastPatterns to `patterns`; the op
+/// itself adds nothing beyond that call.
+void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateDropUnitDimWithShapeCastPatterns(patterns);
+}
+
void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorBitCastLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 8fcef54f12edfa..7f6b2303f86e10 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2056,8 +2056,13 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
- ShapeCastOpFolder, DropUnitDimsFromScfForOp>(
+ // TODO: Consider either:
+ // * including DropInnerMostUnitDimsTransferRead and
+ // DropInnerMostUnitDimsTransferWrite, or
+ // * better naming to distinguish this and
+ // populateVectorTransferCollapseInnerMostContiguousDimsPatterns.
+ // Patterns are listed alphabetically and all share the same `benefit`.
+ patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
+ DropUnitDimsFromTransposeOp, ShapeCastOpFolder>(
patterns.getContext(), benefit);
}
diff --git a/mlir/test/Dialect/Vector/shape-cast-folder.mlir b/mlir/test/Dialect/Vector/shape-cast-folder.mlir
new file mode 100644
index 00000000000000..9550c5c4ae0561
--- /dev/null
+++ b/mlir/test/Dialect/Vector/shape-cast-folder.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+
+///----------------------------------------------------------------------------------------
+/// [Pattern: ShapeCastOpFolder]
+///----------------------------------------------------------------------------------------
+
+// A shape_cast pair that round-trips vector<2x4xf32> -> vector<8xf32> ->
+// vector<2x4xf32> is expected to fold away, leaving the argument returned as-is.
+// CHECK-LABEL: func @fixed_width
+// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
+// CHECK-NOT: vector.shape_cast
+// CHECK: return %[[A0]] : vector<2x4xf32>
+func.func @fixed_width(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
+ %1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>
+ return %1 : vector<2x4xf32>
+}
+
+// Same round-trip as above, but with a scalable trailing dimension.
+// CHECK-LABEL: func @scalable
+// CHECK-SAME: %[[A0:.*0]]: vector<2x[4]xf32>
+// CHECK-NOT: vector.shape_cast
+// CHECK: return %[[A0]] : vector<2x[4]xf32>
+func.func @scalable(%arg0 : vector<2x[4]xf32>) -> vector<2x[4]xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2x[4]xf32> to vector<[8]xf32>
+ %1 = vector.shape_cast %0 : vector<[8]xf32> to vector<2x[4]xf32>
+ return %1 : vector<2x[4]xf32>
+}
+
+// ============================================================================
+// TD sequence
+// ============================================================================
+// Applies only the drop_unit_dims_with_shape_cast pattern set (which includes
+// ShapeCastOpFolder) to every func.func in the input.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.drop_unit_dims_with_shape_cast
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index eda6a5cc40d999..89e8ca1d93109c 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -184,15 +184,6 @@ func.func @vector_transfers(%arg0: index, %arg1: index) {
return
}
-// CHECK-LABEL: func @cancelling_shape_cast_ops
-// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
-// CHECK: return %[[A0]] : vector<2x4xf32>
-func.func @cancelling_shape_cast_ops(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
- %0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
- %1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>
- return %1 : vector<2x4xf32>
-}
-
// CHECK-LABEL: func @elementwise_unroll
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32>, %[[ARG1:.*]]: memref<4x4xf32>)
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
More information about the Mlir-commits
mailing list