[Mlir-commits] [mlir] 8850728 - [mlir][Vector] Add a pattern to lower 2-D vector.transpose to shape_cast+shuffle.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Nov 2 15:12:50 PDT 2021


Author: Nicolas Vasilache
Date: 2021-11-02T22:12:46Z
New Revision: 885072820c4ededc5ffa570ba8fd89b001fb9f11

URL: https://github.com/llvm/llvm-project/commit/885072820c4ededc5ffa570ba8fd89b001fb9f11
DIFF: https://github.com/llvm/llvm-project/commit/885072820c4ededc5ffa570ba8fd89b001fb9f11.diff

LOG: [mlir][Vector] Add a pattern to lower 2-D vector.transpose to shape_cast+shuffle.

The 2-D case can be rewritten to generate far fewer instructions and a single vector.shuffle, which seems to provide a nice performance boost.
Add this arrow to our quiver by exposing it with a new vector transform option.

Differential Revision: https://reviews.llvm.org/D113062

Added: 
    mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
index 587f334bc0473..433ab8df0571e 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
@@ -29,6 +29,8 @@ enum class VectorTransposeLowering {
   /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
   /// intrinsics.
   Flat = 1,
+  /// Lower 2-D transpose to `vector.shuffle`.
+  Shuffle = 2,
 };
 /// Enum to control the lowering of `vector.multi_reduction` operations.
 enum class VectorMultiReductionLowering {

diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index c6f29e0a641a2..912031035ed7e 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -686,6 +686,12 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     for (auto attr : op.transp())
       transp.push_back(attr.cast<IntegerAttr>().getInt());
 
+    if (vectorTransformOptions.vectorTransposeLowering ==
+            vector::VectorTransposeLowering::Shuffle &&
+        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
+      return rewriter.notifyMatchFailure(
+          op, "Options specifies lowering to shuffle");
+
     // Handle a true 2-D matrix transpose differently when requested.
     if (vectorTransformOptions.vectorTransposeLowering ==
             vector::VectorTransposeLowering::Flat &&
@@ -740,6 +746,61 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
   vector::VectorTransformsOptions vectorTransformOptions;
 };
 
+/// Rewrite a 2-D vector.transpose with permutation [1, 0] as a sequence of:
+///   vector.shape_cast 2D -> 1D (row-major flatten of the m x n source)
+///   vector.shuffle             (result[j * m + i] = source[i * n + j])
+///   vector.shape_cast 1D -> 2D (reshape to the n x m result)
+class TransposeOp2DToShuffleLowering
+    : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  TransposeOp2DToShuffleLowering(
+      vector::VectorTransformsOptions vectorTransformOptions,
+      MLIRContext *context)
+      : OpRewritePattern<vector::TransposeOp>(context),
+        vectorTransformOptions(vectorTransformOptions) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    // The pattern is opt-in: reject cheaply unless Shuffle was requested.
+    if (vectorTransformOptions.vectorTransposeLowering !=
+        VectorTransposeLowering::Shuffle)
+      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
+
+    VectorType srcType = op.getVectorType();
+    if (srcType.getRank() != 2)
+      return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
+
+    SmallVector<int64_t, 4> transp;
+    for (auto attr : op.transp())
+      transp.push_back(attr.cast<IntegerAttr>().getInt());
+    // Only [1, 0] swaps the two dimensions; [0, 1] is the identity.
+    if (transp[0] != 1 || transp[1] != 0)
+      return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
+
+    Location loc = op.getLoc();
+    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
+    Value casted = rewriter.create<vector::ShapeCastOp>(
+        loc, VectorType::get({m * n}, srcType.getElementType()), op.vector());
+    // Result (j, i) reads flat source index i * n + j (row-major source).
+    SmallVector<int64_t> mask;
+    mask.reserve(m * n);
+    for (int64_t j = 0; j < n; ++j)
+      for (int64_t i = 0; i < m; ++i)
+        mask.push_back(i * n + j);
+    Value shuffled =
+        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
+                                                     shuffled);
+    return success();
+  }
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+};
+
 /// Progressive lowering of OuterProductOp.
 /// One:
 ///   %x = vector.outerproduct %lhs, %rhs, %acc
@@ -3656,7 +3717,8 @@ void mlir::vector::populateVectorContractLoweringPatterns(
 
 void mlir::vector::populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns, VectorTransformsOptions options) {
   patterns.add<TransposeOpLowering>(options, patterns.getContext());
+  patterns.add<TransposeOp2DToShuffleLowering>(options, patterns.getContext());
 }
 
 void mlir::vector::populateVectorReductionToContractPatterns(

diff --git a/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir
new file mode 100644
index 0000000000000..1b65579b5c813
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-shuffle-transpose=1 | FileCheck %s
+
+// CHECK-LABEL: func @transpose
+func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
+  // CHECK: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32>
+  //            0 4     Flatten row-major; result (j, i) reads
+  // 0 1 2 3    1 5     flat source index i * 4 + j, so the
+  // 4 5 6 7 -> 2 6     single shuffle uses the mask
+  //            3 7     [0, 4, 1, 5, 2, 6, 3, 7].
+  // CHECK: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32>
+  // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<4x2xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}

diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index e7d520bcdb173..12d57489af60b 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -116,6 +116,10 @@ struct TestVectorContractionConversion
       *this, "vector-flat-transpose",
       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
       llvm::cl::init(false)};
+  Option<bool> lowerToShuffleTranspose{
+      *this, "vector-shuffle-transpose",
+      llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
+      llvm::cl::init(false)};
   Option<bool> lowerToOuterProduct{
       *this, "vector-outerproduct",
       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
@@ -165,12 +169,15 @@ struct TestVectorContractionConversion
         VectorTransposeLowering::EltWise;
     if (lowerToFlatTranspose)
       transposeLowering = VectorTransposeLowering::Flat;
+    if (lowerToShuffleTranspose)
+      transposeLowering = VectorTransposeLowering::Shuffle;
     VectorTransformsOptions options{
         contractLowering, vectorMultiReductionLowering, transposeLowering};
     populateVectorBroadcastLoweringPatterns(patterns);
     populateVectorContractLoweringPatterns(patterns, options);
     populateVectorMaskOpLoweringPatterns(patterns);
-    populateVectorShapeCastLoweringPatterns(patterns);
+    if (!lowerToShuffleTranspose)
+      populateVectorShapeCastLoweringPatterns(patterns);
     populateVectorTransposeLoweringPatterns(patterns, options);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }


        


More information about the Mlir-commits mailing list