[Mlir-commits] [mlir] [mlir][vector] Make `TransposeOpLowering` configurable (PR #73915)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 30 01:38:29 PST 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

This patch follows up on the discussion in:
  * https://github.com/llvm/llvm-project/pull/72105

It makes `TransposeOpLowering` configurable, so that one can choose
whether to favour `vector.shape_cast` over `vector.transpose`.

As discussed in #72105, lowering `vector.transpose` to `vector.shape_cast`
is beneficial when targeting `LLVM IR` (CPU lowering), but does not work
when targeting `SPIR-V` (GPU lowering). We therefore need a mechanism to
enable or disable the pattern introduced in #72105. This patch proposes
one: a `useShapeCast` flag in `VectorTransformsOptions`, which defaults to
`true`.
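
To make the effect of the new flag concrete, here is a minimal, standalone C++ model of the check that the diff places under `useShapeCast`. `VectorTypeModel`, `OptionsModel` and `lowersToShapeCast` are hypothetical stand-ins written for illustration only, not MLIR API. They mirror the conditions tested in `TransposeOpLowering`: the result has rank 2, the permutation is `[1, 0]`, and either the leading or the trailing dimension is a fixed (non-scalable) unit dimension.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a vector type: only the fields the pattern
// inspects (shape and per-dimension "scalable" flags). NOT the MLIR API.
struct VectorTypeModel {
  std::vector<std::int64_t> shape;
  std::vector<bool> scalableDims;
};

// Hypothetical stand-in for VectorTransformsOptions, modelling only the
// new flag (defaults to true, as in the patch).
struct OptionsModel {
  bool useShapeCast = true;
};

// Returns true if, under this model, a 2-D transpose producing `resType`
// with permutation `perm` would be rewritten to vector.shape_cast.
bool lowersToShapeCast(const OptionsModel &opts,
                       const VectorTypeModel &resType,
                       const std::vector<std::int64_t> &perm) {
  assert(resType.shape.size() == resType.scalableDims.size());
  // Flag off: skip the shape_cast rewrite entirely.
  if (!opts.useShapeCast)
    return false;
  // Only rank-2 results are considered.
  if (resType.shape.size() != 2)
    return false;
  // The unit dim must be fixed; the non-unit dim may be scalable.
  const bool leadingFixedUnit =
      resType.shape.front() == 1 && !resType.scalableDims.front();
  const bool trailingFixedUnit =
      resType.shape.back() == 1 && !resType.scalableDims.back();
  return (leadingFixedUnit || trailingFixedUnit) &&
         perm == std::vector<std::int64_t>{1, 0};
}
```

For example, a result type modelled as `{{1, 4}, {false, true}}` (roughly `vector<1x[4]xf32>`) qualifies with the default options, but not with `OptionsModel{false}`. In the real pattern, disabling the flag simply lets the rewrite fall through to the remaining lowering paths.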

While this should solve the problem that we are facing today, we may
need something more elaborate to specialise for CPU vs GPU lowering.
Also, once implemented, the following proposal might make this
workaround redundant:
  * https://discourse.llvm.org/t/improving-handling-of-unit-dimensions-in-the-vector-dialect/


---
Full diff: https://github.com/llvm/llvm-project/pull/73915.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h (+2) 
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (+18-16) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
index 08d3bb157a0e396..8f300d66c9a18f2 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
@@ -59,6 +59,8 @@ struct VectorTransformsOptions {
     vectorTransferSplit = opt;
     return *this;
   }
+
+  bool useShapeCast = true;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 97f6caca1b25ccc..4d43a76c4a4efcc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -334,22 +334,24 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       return rewriter.notifyMatchFailure(
           op, "Options specifies lowering to shuffle");
 
-    // Replace:
-    //   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
-    //                                 vector<1xnxelty>
-    // with:
-    //   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
-    //
-    // Source with leading unit dim (inverse) is also replaced. Unit dim must
-    // be fixed. Non-unit can be scalable.
-    if (resType.getRank() == 2 &&
-        ((resType.getShape().front() == 1 &&
-          !resType.getScalableDims().front()) ||
-         (resType.getShape().back() == 1 &&
-          !resType.getScalableDims().back())) &&
-        transp == ArrayRef<int64_t>({1, 0})) {
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
-      return success();
+    if (vectorTransformOptions.useShapeCast) {
+      // Replace:
+      //   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
+      //                                 vector<1xnxelty>
+      // with:
+      //   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
+      //
+      // Source with leading unit dim (inverse) is also replaced. Unit dim must
+      // be fixed. Non-unit can be scalable.
+      if (resType.getRank() == 2 &&
+          ((resType.getShape().front() == 1 &&
+            !resType.getScalableDims().front()) ||
+           (resType.getShape().back() == 1 &&
+            !resType.getScalableDims().back())) &&
+          transp == ArrayRef<int64_t>({1, 0})) {
+        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
+        return success();
+      }
     }
 
     if (inputType.isScalable())

``````````

</details>
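
Why is the rewrite in the diff valid? A `vector.shape_cast` only changes the shape and keeps the row-major (flattened) element order. The sketch below uses plain `std::vector<int>` buffers as an assumed stand-in for fixed-size vectors (scalable dims are not modelled). The helpers `transpose2D` and `transposeIsShapeCast` are hypothetical, written for illustration. It checks that transposing an `n x 1` (or `1 x n`) buffer leaves the flattened order unchanged, whereas a general transpose such as `2 x 3` does not.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Transposes a row-major rows x cols matrix stored in a flat buffer; the
// result is a row-major cols x rows matrix.
std::vector<int> transpose2D(const std::vector<int> &src, std::size_t rows,
                             std::size_t cols) {
  assert(src.size() == rows * cols);
  std::vector<int> dst(rows * cols);
  for (std::size_t r = 0; r < rows; ++r)
    for (std::size_t c = 0; c < cols; ++c)
      dst[c * rows + r] = src[r * cols + c];
  return dst;
}

// A shape_cast reinterprets the same flat buffer under a new shape, so a
// transpose can be replaced by a shape_cast exactly when transposing leaves
// the flat buffer unchanged. Filling the buffer with distinct values makes
// any reordering visible.
bool transposeIsShapeCast(std::size_t rows, std::size_t cols) {
  std::vector<int> src(rows * cols);
  for (std::size_t i = 0; i < src.size(); ++i)
    src[i] = static_cast<int>(i);
  return transpose2D(src, rows, cols) == src;
}
```

This model only covers fixed-size shapes. The patch additionally requires the unit dimension to be fixed, while the non-unit dimension may be scalable.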


https://github.com/llvm/llvm-project/pull/73915

