[Mlir-commits] [mlir] cbd72cb - [mlir][vector] Split `TransposeOpLowering` into 2 patterns (#91935)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 14 04:09:01 PDT 2024
Author: Andrzej Warzyński
Date: 2024-05-14T12:08:57+01:00
New Revision: cbd72cb0deec31a5c3063cf1f1af759761115eee
URL: https://github.com/llvm/llvm-project/commit/cbd72cb0deec31a5c3063cf1f1af759761115eee
DIFF: https://github.com/llvm/llvm-project/commit/cbd72cb0deec31a5c3063cf1f1af759761115eee.diff
LOG: [mlir][vector] Split `TransposeOpLowering` into 2 patterns (#91935)
Splits `TransposeOpLowering` into two patterns:
1. `Transpose2DWithUnitDimToShapeCast` - rewrites 2D `vector.transpose`
as `vector.shape_cast` (there has to be at least one unit dim),
2. `TransposeOpLowering` - the original pattern without the part
extracted into `Transpose2DWithUnitDimToShapeCast`.
The rationale behind the split:
* the output generated by `Transpose2DWithUnitDimToShapeCast` doesn't
really match the intended output from `TransposeOpLowering` as
documented in the source file - it doesn't make much sense to keep
it embedded inside `TransposeOpLowering`,
* `Transpose2DWithUnitDimToShapeCast` _does_ work for scalable vectors,
`TransposeOpLowering` _does_ not.
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 7011c478fefba..ca8a6f6d82a6e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -326,6 +326,10 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
VectorType inputType = op.getSourceVectorType();
VectorType resType = op.getResultVectorType();
+ if (inputType.isScalable())
+ return rewriter.notifyMatchFailure(
+ op, "This lowering does not support scalable vectors");
+
// Set up convenience transposition table.
ArrayRef<int64_t> transp = op.getPermutation();
@@ -334,28 +338,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
return rewriter.notifyMatchFailure(
op, "Options specifies lowering to shuffle");
- // Replace:
- // vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
- // vector<1xnxelty>
- // with:
- // vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
- //
- // Source with leading unit dim (inverse) is also replaced. Unit dim must
- // be fixed. Non-unit can be scalable.
- if (resType.getRank() == 2 &&
- ((resType.getShape().front() == 1 &&
- !resType.getScalableDims().front()) ||
- (resType.getShape().back() == 1 &&
- !resType.getScalableDims().back())) &&
- transp == ArrayRef<int64_t>({1, 0})) {
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
- return success();
- }
-
- // TODO: Add support for scalable vectors
- if (inputType.isScalable())
- return failure();
-
// Handle a true 2-D matrix transpose differently when requested.
if (vectorTransformOptions.vectorTransposeLowering ==
vector::VectorTransposeLowering::Flat &&
@@ -411,6 +393,64 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
vector::VectorTransformsOptions vectorTransformOptions;
};
+/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
+/// to 2D vectors with at least one unit dim. For example:
+///
+/// Replace:
+/// vector.transpose %0, [1, 0] : vector<4x1xi32> to
+/// vector<1x4xi32>
+/// with:
+/// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
+///
+/// Source with leading unit dim (inverse) is also replaced. Unit dim must
+/// be fixed. Non-unit dim can be scalable.
+///
+/// TODO: This pattern was introduced specifically to help lower scalable
+/// vectors. In hindsight, a more specialised canonicalization (for shape_casts
+/// to cancel out) would be preferable:
+///
+/// BEFORE:
+/// %0 = some_op
+/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
+/// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+/// AFTER:
+/// %0 = some_op
+/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
+///
+/// Given the context above, we may want to consider (re-)moving this pattern
+/// at some later time. I am leaving it for now in case there are other users
+/// that I am not aware of.
+class Transpose2DWithUnitDimToShapeCast
+ : public OpRewritePattern<vector::TransposeOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+ LogicalResult matchAndRewrite(vector::TransposeOp op,
+ PatternRewriter &rewriter) const override {
+ Value input = op.getVector();
+ VectorType resType = op.getResultVectorType();
+
+ // Set up convenience transposition table.
+ ArrayRef<int64_t> transp = op.getPermutation();
+
+ if (resType.getRank() == 2 &&
+ ((resType.getShape().front() == 1 &&
+ !resType.getScalableDims().front()) ||
+ (resType.getShape().back() == 1 &&
+ !resType.getScalableDims().back())) &&
+ transp == ArrayRef<int64_t>({1, 0})) {
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
+ return success();
+ }
+
+ return failure();
+ }
+};
+
/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
/// If the strategy is Shuffle1D, it will be lowered to:
/// vector.shape_cast 2D -> 1D
@@ -483,6 +523,8 @@ class TransposeOp2DToShuffleLowering
void mlir::vector::populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns, VectorTransformsOptions options,
PatternBenefit benefit) {
+ patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
+ benefit);
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
options, patterns.getContext(), benefit);
}
More information about the Mlir-commits
mailing list