[Mlir-commits] [mlir] bbd2b08 - [mlir][vector] Make `TransposeOpLowering` configurable (#73915)

llvmlistbot at llvm.org
Mon Dec 4 08:56:47 PST 2023


Author: Andrzej Warzyński
Date: 2023-12-04T16:56:43Z
New Revision: bbd2b08b95fe76bea138c1b03c1cd42ed3ee04df

URL: https://github.com/llvm/llvm-project/commit/bbd2b08b95fe76bea138c1b03c1cd42ed3ee04df
DIFF: https://github.com/llvm/llvm-project/commit/bbd2b08b95fe76bea138c1b03c1cd42ed3ee04df.diff

LOG: [mlir][vector] Make `TransposeOpLowering` configurable (#73915)

Following the discussion here:

  * https://github.com/llvm/llvm-project/pull/72105

this patch makes the `TransposeOpLowering` configurable so that one can select
whether to favour `vector.shape_cast` over `vector.transpose`.

As per the discussion in #72105, using `vector.shape_cast` is very beneficial
and desirable when targeting `LLVM IR` (CPU lowering), but won't work when
targeting `SPIR-V` today (GPU lowering). Hence the need for a mechanism to be
able to disable/enable the pattern introduced in #72105. This patch proposes one
such mechanism.

While this should solve the problem that we are facing today, it's understood to
be a temporary workaround. It should be removed once support for lowering
`vector.shape_cast` to SPIR-V is added. Also, (once implemented) the following
proposal might make this workaround redundant:

  * https://discourse.llvm.org/t/improving-handling-of-unit-dimensions-in-the-vector-dialect/

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
    mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
index 08d3bb157a0e3..41ffc92994602 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
@@ -59,6 +59,16 @@ struct VectorTransformsOptions {
     vectorTransferSplit = opt;
     return *this;
   }
+
+  /// Option to control if vector.transpose can lower to a vector.shape_cast.
+  /// TODO: ATM it's not possible to lower `vector.shape_cast` to SPIR-V
+  /// and hence the need for this opt-out. Once the missing support has been
+  /// added, this option can be removed.
+  bool useShapeCast = true;
+  VectorTransformsOptions &setUseShapeCast(bool opt = true) {
+    useShapeCast = opt;
+    return *this;
+  }
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 97f6caca1b25c..4d43a76c4a4ef 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -334,22 +334,24 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       return rewriter.notifyMatchFailure(
           op, "Options specifies lowering to shuffle");
 
-    // Replace:
-    //   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
-    //                                 vector<1xnxelty>
-    // with:
-    //   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
-    //
-    // Source with leading unit dim (inverse) is also replaced. Unit dim must
-    // be fixed. Non-unit can be scalable.
-    if (resType.getRank() == 2 &&
-        ((resType.getShape().front() == 1 &&
-          !resType.getScalableDims().front()) ||
-         (resType.getShape().back() == 1 &&
-          !resType.getScalableDims().back())) &&
-        transp == ArrayRef<int64_t>({1, 0})) {
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
-      return success();
+    if (vectorTransformOptions.useShapeCast) {
+      // Replace:
+      //   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
+      //                                 vector<1xnxelty>
+      // with:
+      //   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
+      //
+      // Source with leading unit dim (inverse) is also replaced. Unit dim must
+      // be fixed. Non-unit can be scalable.
+      if (resType.getRank() == 2 &&
+          ((resType.getShape().front() == 1 &&
+            !resType.getScalableDims().front()) ||
+           (resType.getShape().back() == 1 &&
+            !resType.getScalableDims().back())) &&
+          transp == ArrayRef<int64_t>({1, 0})) {
+        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
+        return success();
+      }
     }
 
     if (inputType.isScalable())
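
For readers who want to see the effect of the new option in isolation, the following is a minimal, standalone sketch and not the in-tree code. `MockVectorTransformsOptions`, `MockVectorType`, `Lowering` and `chooseLowering` are hypothetical stand-ins introduced only for illustration; they model just the `useShapeCast` flag and the guarded condition from the hunk above, with no MLIR dependency.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for mlir::vector::VectorTransformsOptions.
// Only the member added by this patch is modelled.
struct MockVectorTransformsOptions {
  bool useShapeCast = true;
  MockVectorTransformsOptions &setUseShapeCast(bool opt = true) {
    useShapeCast = opt;
    return *this;
  }
};

// Hypothetical, simplified description of the transpose result type:
// one extent and one "is scalable" flag per dimension.
struct MockVectorType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;
};

// Illustrative result: which path the pattern would take.
enum class Lowering { ShapeCast, Other };

// Mirrors the condition guarded by `useShapeCast` in the patch:
// a rank-2 result with a fixed (non-scalable) unit dim at the front or
// back, and the permutation [1, 0].
Lowering chooseLowering(const MockVectorTransformsOptions &options,
                        const MockVectorType &resType,
                        const std::vector<int64_t> &transp) {
  if (options.useShapeCast) {
    const bool isRank2 = resType.shape.size() == 2;
    const bool hasFixedUnitDim =
        isRank2 &&
        ((resType.shape.front() == 1 && !resType.scalableDims.front()) ||
         (resType.shape.back() == 1 && !resType.scalableDims.back()));
    if (hasFixedUnitDim && transp == std::vector<int64_t>{1, 0})
      return Lowering::ShapeCast;
  }
  // With the option disabled (e.g. for a SPIR-V pipeline), or when the
  // shape does not qualify, the remaining lowerings are tried instead.
  return Lowering::Other;
}
```

In this sketch, `MockVectorTransformsOptions().setUseShapeCast(false)` plays the role of what a SPIR-V-targeting caller would do with the real options struct before handing it to the pattern; how the real options object reaches `TransposeOpLowering` is not shown in this diff.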


        

