[Mlir-commits] [mlir] [mlir][vector] Split `TransposeOpLowering` into 2 patterns (PR #91935)

Andrzej Warzyński llvmlistbot at llvm.org
Tue May 14 02:01:04 PDT 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/91935

>From c9709bbe09f5e627ee7b4643e7778c3265147e3d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 13 May 2024 08:32:06 +0000
Subject: [PATCH 1/2] [mlir][vector] Split `TransposeOpLowering` into 2
 patterns

Splits `TransposeOpLowering` into two patterns:
  1. `Transpose2DWithUnitDimToShapeCast` - rewrites 2D `vector.transpose`
    as `vector.shape_cast` (there has to be at least one fixed-size,
    i.e. non-scalable, unit dim),
  2. `TransposeOpLowering` - the original pattern without the part
    extracted into `Transpose2DWithUnitDimToShapeCast`.

The rationale behind the split:
  * the output generated by `Transpose2DWithUnitDimToShapeCast` doesn't
    really match the intended output from `TransposeOpLowering` as
    documented in the source file - it doesn't make much sense to keep
    it embedded inside `TransposeOpLowering`,
  * `Transpose2DWithUnitDimToShapeCast` _does_ work for scalable vectors,
    `TransposeOpLowering` _does_ not.
---
 .../Transforms/LowerVectorTranspose.cpp       | 70 +++++++++++++------
 1 file changed, 48 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 7011c478fefba..0706f22cb8b12 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -326,6 +326,10 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     VectorType inputType = op.getSourceVectorType();
     VectorType resType = op.getResultVectorType();
 
+    if (inputType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "This lowering does not support scalable vectors");
+
     // Set up convenience transposition table.
     ArrayRef<int64_t> transp = op.getPermutation();
 
@@ -334,28 +338,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       return rewriter.notifyMatchFailure(
           op, "Options specifies lowering to shuffle");
 
-    // Replace:
-    //   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
-    //                                 vector<1xnxelty>
-    // with:
-    //   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
-    //
-    // Source with leading unit dim (inverse) is also replaced. Unit dim must
-    // be fixed. Non-unit can be scalable.
-    if (resType.getRank() == 2 &&
-        ((resType.getShape().front() == 1 &&
-          !resType.getScalableDims().front()) ||
-         (resType.getShape().back() == 1 &&
-          !resType.getScalableDims().back())) &&
-        transp == ArrayRef<int64_t>({1, 0})) {
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
-      return success();
-    }
-
-    // TODO: Add support for scalable vectors
-    if (inputType.isScalable())
-      return failure();
-
     // Handle a true 2-D matrix transpose differently when requested.
     if (vectorTransformOptions.vectorTransposeLowering ==
             vector::VectorTransposeLowering::Flat &&
@@ -411,6 +393,48 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
   vector::VectorTransformsOptions vectorTransformOptions;
 };
 
+/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
+/// to 2D vectors with at least one non-scalable unit dim. For example:
+///
+/// Replace:
+///   vector.transpose %0, [1, 0] : vector<4x1xi32> to
+///                                 vector<1x4xi32>
+/// with:
+///   vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
+///
+/// Source with leading unit dim (inverse) is also replaced. Unit dim must
+/// be fixed. Non-unit can be scalable.
+class Transpose2DWithUnitDimToShapeCast
+    : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
+                                    PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.getVector();
+    VectorType resType = op.getResultVectorType();
+
+    // Permutation; for a rank-2 vector, [1, 0] is the only non-identity one.
+    ArrayRef<int64_t> transp = op.getPermutation();
+
+    if (resType.getRank() == 2 &&
+        ((resType.getShape().front() == 1 &&
+          !resType.getScalableDims().front()) ||
+         (resType.getShape().back() == 1 &&
+          !resType.getScalableDims().back())) &&
+        transp == ArrayRef<int64_t>({1, 0})) {
+      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
 /// If the strategy is Shuffle1D, it will be lowered to:
 ///   vector.shape_cast 2D -> 1D
@@ -483,6 +507,8 @@ class TransposeOp2DToShuffleLowering
 void mlir::vector::populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns, VectorTransformsOptions options,
     PatternBenefit benefit) {
+  patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
+                                                  benefit);
   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
       options, patterns.getContext(), benefit);
 }

>From 5dc478eb9845a78b11493f51285f958e1e992818 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 14 May 2024 09:00:54 +0000
Subject: [PATCH 2/2] fixup! [mlir][vector] Split `TransposeOpLowering` into 2
 patterns

Add a TODO
---
 .../Vector/Transforms/LowerVectorTranspose.cpp | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 0706f22cb8b12..ca8a6f6d82a6e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -403,7 +403,23 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
 ///   vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
 ///
 /// Source with leading unit dim (inverse) is also replaced. Unit dim must
-/// be fixed. Non-unit can be scalable.
+/// be fixed. Non-unit dim can be scalable.
+///
+/// TODO: This pattern was introduced specifically to help lower scalable
+/// vectors. In hindsight, a more specialised canonicalization (folding a
+/// shape_cast + transpose pair into a single shape_cast) would be preferable:
+///
+///  BEFORE:
+///     %0 = some_op
+///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
+///     %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+///  AFTER:
+///     %0 = some_op
+///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
+///
+/// Given the context above, we may want to consider (re-)moving this pattern
+/// at some later time. It is being kept for now in case there are other
+/// (e.g. downstream) users that rely on it.
 class Transpose2DWithUnitDimToShapeCast
     : public OpRewritePattern<vector::TransposeOp> {
 public:



More information about the Mlir-commits mailing list