[Mlir-commits] [mlir] 8850728 - [mlir][Vector] Add a pattern to lower 2-D vector.transpose to shape_cast+shuffle.
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Nov 2 15:12:50 PDT 2021
Author: Nicolas Vasilache
Date: 2021-11-02T22:12:46Z
New Revision: 885072820c4ededc5ffa570ba8fd89b001fb9f11
URL: https://github.com/llvm/llvm-project/commit/885072820c4ededc5ffa570ba8fd89b001fb9f11
DIFF: https://github.com/llvm/llvm-project/commit/885072820c4ededc5ffa570ba8fd89b001fb9f11.diff
LOG: [mlir][Vector] Add a pattern to lower 2-D vector.transpose to shape_cast+shuffle.
The 2-D case can be rewritten to generate far fewer instructions (two shape_casts around a single vector.shuffle), which seems to provide a nice performance boost.
Add this arrow to our quiver by exposing it with a new vector transform option.
Differential Revision: https://reviews.llvm.org/D113062
Added:
mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir
Modified:
mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
index 587f334bc0473..433ab8df0571e 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
@@ -29,6 +29,8 @@ enum class VectorTransposeLowering {
/// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
/// intrinsics.
Flat = 1,
+ /// Lower 2-D transpose to `vector.shuffle`.
+ Shuffle = 2,
};
/// Enum to control the lowering of `vector.multi_reduction` operations.
enum class VectorMultiReductionLowering {
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index c6f29e0a641a2..912031035ed7e 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -686,6 +686,12 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
for (auto attr : op.transp())
transp.push_back(attr.cast<IntegerAttr>().getInt());
+ if (vectorTransformOptions.vectorTransposeLowering ==
+ vector::VectorTransposeLowering::Shuffle &&
+ resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
+ return rewriter.notifyMatchFailure(
+ op, "Options specifies lowering to shuffle");
+
// Handle a true 2-D matrix transpose differently when requested.
if (vectorTransformOptions.vectorTransposeLowering ==
vector::VectorTransposeLowering::Flat &&
@@ -740,6 +746,61 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
vector::VectorTransformsOptions vectorTransformOptions;
};
+/// Rewrite a 2-D vector.transpose as a sequence of:
+///   vector.shape_cast 2D -> 1D
+///   vector.shuffle
+///   vector.shape_cast 1D -> 2D
+///
+/// For an `m x n` source, element (i, j) sits at flat index `i * n + j` after
+/// the first shape_cast. The `n x m` result holds that element at (j, i), i.e.
+/// at flat position `j * m + i`, so the shuffle mask enumerates source flat
+/// indices column by column.
+///
+/// The pattern only fires when `vectorTransposeLowering` is
+/// `VectorTransposeLowering::Shuffle` and the op is a rank-2 `[1, 0]`
+/// permutation.
+class TransposeOp2DToShuffleLowering
+    : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  TransposeOp2DToShuffleLowering(
+      vector::VectorTransformsOptions vectorTransformOptions,
+      MLIRContext *context)
+      : OpRewritePattern<vector::TransposeOp>(context),
+        vectorTransformOptions(vectorTransformOptions) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    // Cheapest check first: this lowering is strictly opt-in.
+    if (vectorTransformOptions.vectorTransposeLowering !=
+        VectorTransposeLowering::Shuffle)
+      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
+
+    auto loc = op.getLoc();
+
+    VectorType srcType = op.getVectorType();
+    if (srcType.getRank() != 2)
+      return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
+
+    SmallVector<int64_t, 4> transp;
+    for (auto attr : op.transp())
+      transp.push_back(attr.cast<IntegerAttr>().getInt());
+    // Accept exactly the [1, 0] permutation; any other entry means the op is
+    // not a true 2-D transpose (e.g. the identity [0, 1]).
+    if (transp[0] != 1 || transp[1] != 0)
+      return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
+
+    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
+    Value casted = rewriter.create<vector::ShapeCastOp>(
+        loc, VectorType::get({m * n}, srcType.getElementType()), op.vector());
+
+    // Result flat position `j * m + i` takes source flat index `i * n + j`.
+    SmallVector<int64_t> mask;
+    mask.reserve(m * n);
+    for (int64_t j = 0; j < n; ++j)
+      for (int64_t i = 0; i < m; ++i)
+        mask.push_back(i * n + j);
+
+    // Both operands are `casted`. Every mask index is < m * n, so only lanes
+    // of the first operand are selected.
+    Value shuffled =
+        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
+                                                     shuffled);
+    return success();
+  }
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+};
+
/// Progressive lowering of OuterProductOp.
/// One:
/// %x = vector.outerproduct %lhs, %rhs, %acc
@@ -3656,7 +3717,8 @@ void mlir::vector::populateVectorContractLoweringPatterns(
void mlir::vector::populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns, VectorTransformsOptions options) {
- patterns.add<TransposeOpLowering>(options, patterns.getContext());
+ // Both patterns are always registered. TransposeOp2DToShuffleLowering only
+ // matches when options.vectorTransposeLowering == Shuffle, and in that mode
+ // TransposeOpLowering declines rank-2 [1, 0] transposes, so the two patterns
+ // do not compete for the same op.
+ patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
+ options, patterns.getContext());
}
void mlir::vector::populateVectorReductionToContractPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir
new file mode 100644
index 0000000000000..1b65579b5c813
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-shuffle-transpose=1 | FileCheck %s
+
+// CHECK-LABEL: func @transpose
+func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
+ // CHECK: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32>
+ // Numbers are flat indices into the 8-element shape_cast result:
+ //                              0 4
+ //   0 1 2 3                    1 5
+ //   4 5 6 7        ->          2 6
+ //                              3 7
+ // Reading the transposed layout row by row yields the shuffle mask below.
+ // CHECK: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32>
+ // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<4x2xf32>
+ %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+ return %0 : vector<4x2xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index e7d520bcdb173..12d57489af60b 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -116,6 +116,10 @@ struct TestVectorContractionConversion
*this, "vector-flat-transpose",
llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
llvm::cl::init(false)};
+ Option<bool> lowerToShuffleTranspose{
+ *this, "vector-shuffle-transpose",
+ llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
+ llvm::cl::init(false)};
Option<bool> lowerToOuterProduct{
*this, "vector-outerproduct",
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
@@ -165,12 +169,15 @@ struct TestVectorContractionConversion
VectorTransposeLowering::EltWise;
if (lowerToFlatTranspose)
transposeLowering = VectorTransposeLowering::Flat;
+ if (lowerToShuffleTranspose)
+ transposeLowering = VectorTransposeLowering::Shuffle;
VectorTransformsOptions options{
contractLowering, vectorMultiReductionLowering, transposeLowering};
populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns, options);
populateVectorMaskOpLoweringPatterns(patterns);
- populateVectorShapeCastLoweringPatterns(patterns);
+ if (!lowerToShuffleTranspose)
+ populateVectorShapeCastLoweringPatterns(patterns);
populateVectorTransposeLoweringPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
More information about the Mlir-commits
mailing list