[Mlir-commits] [mlir] [mlir][vector] Make `TransposeOpLowering` configurable (PR #73915)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Nov 30 01:37:55 PST 2023


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/73915

Following the discussion here:
  * https://github.com/llvm/llvm-project/pull/72105
this patch makes the `TransposeOpLowering` configurable so that one can
select whether to favour `vector.shape_cast` over `vector.transpose`.

As per the discussion in #72105, using `vector.shape_cast` is beneficial
when targeting `LLVM IR` (CPU lowering), but does not work when targeting
`SPIR-V` (GPU lowering). We therefore need a mechanism to enable or
disable the pattern introduced in #72105. This patch proposes one such
mechanism: a `useShapeCast` flag in `VectorTransformsOptions` that
defaults to `true`.
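
As a rough illustration of the condition that the new flag gates, here is a
self-contained sketch. `VecType`, `Options` and `lowersToShapeCast` are
hypothetical stand-ins and not MLIR APIs; only the `useShapeCast` field name
and the rank-2 / fixed-unit-dim / `[1, 0]` check are taken from the patch.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the parts of a vector type used by the check.
struct VecType {
  std::vector<int64_t> shape;     // dimension sizes
  std::vector<bool> scalableDims; // true where a dimension is scalable
};

// Hypothetical stand-in for VectorTransformsOptions; only the new flag.
struct Options {
  bool useShapeCast = true;
};

// Returns true when a transpose with result type `resType` and permutation
// `transp` would take the shape_cast path, mirroring the condition in the
// patch.
bool lowersToShapeCast(const Options &opts, const VecType &resType,
                       const std::vector<int64_t> &transp) {
  // With the flag off, the shape_cast rewrite is skipped entirely.
  if (!opts.useShapeCast)
    return false;
  if (resType.shape.size() != 2)
    return false;
  // The unit dimension must be fixed; the non-unit dimension may be scalable.
  bool leadingFixedUnit =
      resType.shape.front() == 1 && !resType.scalableDims.front();
  bool trailingFixedUnit =
      resType.shape.back() == 1 && !resType.scalableDims.back();
  return (leadingFixedUnit || trailingFixedUnit) &&
         transp == std::vector<int64_t>{1, 0};
}
```

With the default options, a result shape of `1x4` (fixed unit dim) takes the
shape_cast path; setting `useShapeCast = false` makes the same input fall
through to the remaining lowerings.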

While this should solve the problem that we are facing today, we may
need to introduce something more elaborate to specialise for CPU vs GPU
lowering. Also, once implemented, the following proposal might make this
workaround redundant:
  * https://discourse.llvm.org/t/improving-handling-of-unit-dimensions-in-the-vector-dialect/


From 5b2d44f7a137e8323904c70448dc52379090a562 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 30 Nov 2023 09:29:36 +0000
Subject: [PATCH] [mlir][vector] Make `TransposeOpLowering` configurable

Following the discussion here:
  * https://github.com/llvm/llvm-project/pull/72105
this patch makes the `TransposeOpLowering` configurable so that one can
select whether to favour `vector.shape_cast` over `vector.transpose`.

As per the discussion in #72105, using `vector.shape_cast` is beneficial
when targeting `LLVM IR` (CPU lowering), but does not work when targeting
`SPIR-V` (GPU lowering). We therefore need a mechanism to enable or
disable the pattern introduced in #72105. This patch proposes one such
mechanism: a `useShapeCast` flag in `VectorTransformsOptions` that
defaults to `true`.

While this should solve the problem that we are facing today, we may
need to introduce something more elaborate to specialise for CPU vs GPU
lowering. Also, once implemented, the following proposal might make this
workaround redundant:
  * https://discourse.llvm.org/t/improving-handling-of-unit-dimensions-in-the-vector-dialect/
---
 .../Vector/Transforms/VectorTransforms.h      |  2 ++
 .../Transforms/LowerVectorTranspose.cpp       | 34 ++++++++++---------
 2 files changed, 20 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
index 08d3bb157a0e396..8f300d66c9a18f2 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
@@ -59,6 +59,8 @@ struct VectorTransformsOptions {
     vectorTransferSplit = opt;
     return *this;
   }
+
+  bool useShapeCast = true;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 97f6caca1b25ccc..4d43a76c4a4efcc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -334,22 +334,24 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       return rewriter.notifyMatchFailure(
           op, "Options specifies lowering to shuffle");
 
-    // Replace:
-    //   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
-    //                                 vector<1xnxelty>
-    // with:
-    //   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
-    //
-    // Source with leading unit dim (inverse) is also replaced. Unit dim must
-    // be fixed. Non-unit can be scalable.
-    if (resType.getRank() == 2 &&
-        ((resType.getShape().front() == 1 &&
-          !resType.getScalableDims().front()) ||
-         (resType.getShape().back() == 1 &&
-          !resType.getScalableDims().back())) &&
-        transp == ArrayRef<int64_t>({1, 0})) {
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
-      return success();
+    if (vectorTransformOptions.useShapeCast) {
+      // Replace:
+      //   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
+      //                                 vector<1xnx<eltty>>
+      // with:
+      //   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>>
+      //
+      // Source with leading unit dim (inverse) is also replaced. Unit dim must
+      // be fixed. Non-unit can be scalable.
+      if (resType.getRank() == 2 &&
+          ((resType.getShape().front() == 1 &&
+            !resType.getScalableDims().front()) ||
+           (resType.getShape().back() == 1 &&
+            !resType.getScalableDims().back())) &&
+          transp == ArrayRef<int64_t>({1, 0})) {
+        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
+        return success();
+      }
     }
 
     if (inputType.isScalable())


