[Mlir-commits] [mlir] [mlir][vector] Add pattern to drop unit dims from vector.transpose (PR #102017)

Benjamin Maxwell llvmlistbot at llvm.org
Wed Aug 7 08:49:36 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/102017

>From d41ea3c0a203521a351366210cb4b0a1bc895447 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 5 Aug 2024 16:51:44 +0000
Subject: [PATCH 1/3] [mlir][vector] Add pattern to drop unit dims from
 vector.transpose

Example:

BEFORE:
```mlir
%transpose = vector.transpose %vector, [3, 0, 1, 2]
  : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
```

AFTER:
```mlir
%dropDims = vector.shape_cast %vector
  : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
%transpose = vector.transpose %dropDims, [1, 0]
  : vector<4x[4]xf32> to vector<[4]x4xf32>
%restoreDims = vector.shape_cast %transpose
  : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
```
---
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   |  5 ++
 .../Vector/Transforms/VectorTransforms.cpp    | 70 ++++++++++++++++++-
 .../Vector/vector-transfer-flatten.mlir       | 33 +++++++++
 3 files changed, 106 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 40e04b76593a0..67c36bfa06ded 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -120,6 +120,11 @@ inline auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
   };
 }
 
+/// Returns an iterator over the dims (inc scalability) of a VectorType.
+inline auto getDims(VectorType vType) {
+  return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
+}
+
 /// A wrapper for getMixedSizes for vector.transfer_read and
 /// vector.transfer_write Ops (for source and destination, respectively).
 ///
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6777e589795c8..6b39cef7899d9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1720,6 +1720,72 @@ struct DropUnitDimFromElementwiseOps final
   }
 };
 
+/// A pattern to drop non-scalable unit dims from vector.transpose.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %transpose = vector.transpose %vector, [3, 0, 1, 2]
+///    : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %dropDims = vector.shape_cast %vector
+///    : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+///  %transpose = vector.transpose %dropDims, [1, 0]
+///    : vector<4x[4]xf32> to vector<[4]x4xf32>
+///  %restoreDims = vector.shape_cast %transpose
+///    : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+///  ```
+struct DropUnitDimsFromTransposeOp final
+    : OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = op.getSourceVectorType();
+    VectorType sourceTypeWithoutUnitDims =
+        dropNonScalableUnitDimFromType(sourceType);
+
+    if (sourceType == sourceTypeWithoutUnitDims)
+      return failure(); // Type unchanged: nothing to rewrite.
+
+    // Construct a map from dimIdx -> number of dims dropped before dimIdx.
+    auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
+    SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
+    int64_t droppedDims = 0;
+    for (auto [i, dim] : llvm::enumerate(sourceDims)) {
+      droppedDimsBefore[i] = droppedDims;
+      if (dim == std::make_tuple(1, false))
+        ++droppedDims;
+    }
+
+    // Drop unit dims from the permutation and renumber the remaining indices.
+    ArrayRef<int64_t> perm = op.getPermutation();
+    SmallVector<int64_t> newPerm;
+    for (int64_t idx : perm) {
+      if (sourceDims[idx] == std::make_tuple(1, false))
+        continue;
+      newPerm.push_back(idx - droppedDimsBefore[idx]);
+    }
+
+    auto loc = op.getLoc();
+    // Drop the unit dims via shape_cast.
+    auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
+        loc, sourceTypeWithoutUnitDims, op.getVector());
+    // Create the new transpose.
+    auto transposeWithoutUnitDims =
+        rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
+    // Restore the unit dims via shape cast.
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        op, op.getResultVectorType(), transposeWithoutUnitDims);
+
+    return success();
+  }
+};
+
 /// Pattern to eliminate redundant zero-constants added to reduction operands.
 /// It's enough for there to be one initial zero value, so we can eliminate the
 /// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1924,8 +1990,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
 
 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
-      patterns.getContext(), benefit);
+  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
+               ShapeCastOpFolder>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 9d16aa46a9f2a..222a05ff70d02 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -700,3 +700,36 @@ func.func @negative_out_of_bound_transfer_write(
 }
 // CHECK:     func.func @negative_out_of_bound_transfer_write
 // CHECK-NOT:   memref.collapse_shape
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// [Pattern: DropUnitDimsFromTransposeOp]
+/// TODO: Move to a dedicated file - there's no "flattening" in the following tests
+///----------------------------------------------------------------------------------------
+
+func.func @transpose_with_internal_unit_dims(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x1x1x4xf32> {
+  %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+  return %0 : vector<[4]x1x1x4xf32>
+}
+
+// CHECK-LABEL: func.func @transpose_with_internal_unit_dims(
+// CHECK-SAME:                                               %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
+// CHECK-NEXT:    %[[DROP_DIMS:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+// CHECK-NEXT:    %[[TRANSPOSE:.*]] = vector.transpose %[[DROP_DIMS]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+// CHECK-NEXT:    %[[RESTORE_DIMS:.*]] = vector.shape_cast %[[TRANSPOSE]] : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+// CHECK-NEXT:    return %[[RESTORE_DIMS]] : vector<[4]x1x1x4xf32>
+
+// -----
+
+func.func @transpose_with_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x1x1x1x4x1xf32> {
+  %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
+  return %0 : vector<[4]x1x1x1x4x1xf32>
+}
+
+// CHECK-LABEL: func.func @transpose_with_units_dims_before_and_after(
+// CHECK-SAME:                                                        %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
+// CHECK-NEXT:    %[[DROP_DIMS:.*]] = vector.shape_cast %arg0 : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
+// CHECK-NEXT:    %[[TRANSPOSE:.*]] = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+// CHECK-NEXT:    %[[RESTORE_DIMS:.*]] = vector.shape_cast %1 : vector<[4]x4xf32> to vector<[4]x1x1x1x4x1xf32>
+// CHECK-NEXT:    return %[[RESTORE_DIMS]] : vector<[4]x1x1x1x4x1xf32>

>From 3680f20f8ed8d9ca20e3632bb60b64221dc55f2e Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 7 Aug 2024 12:02:35 +0000
Subject: [PATCH 2/3] Update tests

---
 .../Vector/vector-transfer-flatten.mlir       | 35 ++++++++++++-------
 1 file changed, 23 insertions(+), 12 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 222a05ff70d02..937dbf22bb713 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -708,9 +708,9 @@ func.func @negative_out_of_bound_transfer_write(
 /// TODO: Move to a dedicated file - there's no "flattening" in the following tests
 ///----------------------------------------------------------------------------------------
 
-func.func @transpose_with_internal_unit_dims(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x1x1x4xf32> {
-  %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
-  return %0 : vector<[4]x1x1x4xf32>
+func.func @transpose_with_internal_unit_dims(%vec: vector<1x1x4x[4]xf32>) -> vector<[4]x1x1x4xf32> {
+  %res = vector.transpose %vec, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+  return %res : vector<[4]x1x1x4xf32>
 }
 
 // CHECK-LABEL: func.func @transpose_with_internal_unit_dims(
@@ -722,14 +722,25 @@ func.func @transpose_with_internal_unit_dims(%vector: vector<1x1x4x[4]xf32>) ->
 
 // -----
 
-func.func @transpose_with_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x1x1x1x4x1xf32> {
-  %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
-  return %0 : vector<[4]x1x1x1x4x1xf32>
+func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> vector<1x1x4x2x[1]xf32>
+{
+  %res = vector.transpose %vec, [4, 1, 3, 2, 0]  : vector<[1]x1x2x4x1xf32> to vector<1x1x4x2x[1]xf32>
+  return %res: vector<1x1x4x2x[1]xf32>
 }
 
-// CHECK-LABEL: func.func @transpose_with_units_dims_before_and_after(
-// CHECK-SAME:                                                        %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
-// CHECK-NEXT:    %[[DROP_DIMS:.*]] = vector.shape_cast %arg0 : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
-// CHECK-NEXT:    %[[TRANSPOSE:.*]] = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
-// CHECK-NEXT:    %[[RESTORE_DIMS:.*]] = vector.shape_cast %1 : vector<[4]x4xf32> to vector<[4]x1x1x1x4x1xf32>
-// CHECK-NEXT:    return %[[RESTORE_DIMS]] : vector<[4]x1x1x1x4x1xf32>
+// CHECK-LABEL: func.func @transpose_with_scalable_unit_dims(
+// CHECK-SAME:                                               %[[VEC:.*]]: vector<[1]x1x2x4x1xf32>)
+// CHECK-NEXT:    %[[DROP_DIMS:.*]] = vector.shape_cast %[[VEC]] : vector<[1]x1x2x4x1xf32> to vector<[1]x2x4xf32>
+// CHECK-NEXT:    %[[TRANSPOSE:.*]] = vector.transpose %[[DROP_DIMS]], [2, 1, 0] : vector<[1]x2x4xf32> to vector<4x2x[1]xf32>
+// CHECK-NEXT:    %[[RESTORE_DIMS:.*]] = vector.shape_cast %[[TRANSPOSE]] : vector<4x2x[1]xf32> to vector<1x1x4x2x[1]xf32>
+// CHECK-NEXT:    return %[[RESTORE_DIMS]] : vector<1x1x4x2x[1]xf32>
+
+// -----
+
+func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> {
+  %res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
+  return %res : vector<4x3x2xf32>
+}
+
+// CHECK-LABEL: func.func @negative_transpose_with_no_unit_dims
+// CHECK-NOT: vector.shape_cast

>From ecc937581adc20fa6937a78693bad97ba91a58a8 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 7 Aug 2024 15:47:27 +0000
Subject: [PATCH 3/3] Fixups

---
 mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h    | 2 +-
 mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 67c36bfa06ded..5f32aca88a273 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -120,7 +120,7 @@ inline auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
   };
 }
 
-/// Returns an iterator over the dims (inc scalability) of a VectorType.
+/// Returns a range over the dims (size and scalability) of a VectorType.
 inline auto getDims(VectorType vType) {
   return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6b39cef7899d9..55c1c6bad9f2a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1771,7 +1771,7 @@ struct DropUnitDimsFromTransposeOp final
       newPerm.push_back(idx - droppedDimsBefore[idx]);
     }
 
-    auto loc = op.getLoc();
+    Location loc = op.getLoc();
     // Drop the unit dims via shape_cast.
     auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
         loc, sourceTypeWithoutUnitDims, op.getVector());



More information about the Mlir-commits mailing list