[Mlir-commits] [mlir] [mlir][vector] Move transpose with unit-dim to shape_cast pattern (PR #72493)
Diego Caballero
llvmlistbot at llvm.org
Thu Nov 16 11:44:34 PST 2023
================
@@ -5564,12 +5564,51 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};
+/// Folds a transpose with a non-scalable unit dim into a shape_cast.
+///
+/// Replace:
+///   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
+///                                 vector<1xnx<eltty>>
+/// with:
+///   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>>
+///
+/// The inverse case (a source with a leading unit dim) is also replaced. The
+/// unit dim must be fixed (non-scalable); the non-unit dim may be scalable.
+class FoldTransposeWithNonScalableUnitDimsToShapeCast final
+    : public OpRewritePattern<TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp transpOp,
+                                PatternRewriter &rewriter) const override {
+    Value input = transpOp.getVector();
+    VectorType resType = transpOp.getResultVectorType();
+
+    SmallVector<int64_t> permutation;
+    transpOp.getTransp(permutation);
+
+    if (resType.getRank() == 2 &&
+        ((resType.getShape().front() == 1 &&
+          !resType.getScalableDims().front()) ||
+         (resType.getShape().back() == 1 &&
+          !resType.getScalableDims().back())) &&
+        permutation == ArrayRef<int64_t>({1, 0})) {
----------------
dcaballe wrote:
As a follow-up patch, I wonder if we could generalize this to n-D vectors where at most one dimension is != 1? If I'm not missing something, the permutation itself shouldn't even matter in those cases?
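The suggested generalization can be checked with a brute-force sketch, again outside MLIR and using hypothetical names. `VecShape` stands in for `VectorType`; `hasAtMostOneNonUnitDim` encodes the proposed condition, treating a scalable `[1]` dim as non-unit in line with the patch's fixed-unit-dim requirement; and `transposePreservesOrder` tests, for fixed-size shapes only, whether a permutation leaves a row-major buffer unchanged (i.e. behaves like a shape_cast).

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for mlir::VectorType (sizes + scalable flags).
struct VecShape {
  std::vector<int64_t> sizes;
  std::vector<bool> scalable;
};

// Proposed condition: at most one dim that is not a fixed (non-scalable)
// unit dim. A scalable [1] dim counts as non-unit.
bool hasAtMostOneNonUnitDim(const VecShape &t) {
  int nonUnit = 0;
  for (size_t i = 0; i < t.sizes.size(); ++i)
    if (t.sizes[i] != 1 || t.scalable[i])
      ++nonUnit;
  return nonUnit <= 1;
}

// Brute-force oracle for fixed-size shapes: transposing a row-major iota
// buffer with `perm` leaves it unchanged iff every element keeps its
// linear position (result dim i is source dim perm[i]).
bool transposePreservesOrder(const std::vector<int64_t> &shape,
                             const std::vector<int64_t> &perm) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  std::vector<int64_t> strides(rank, 1);  // row-major source strides
  for (int64_t i = rank - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  int64_t numElts = 1;
  for (int64_t s : shape)
    numElts *= s;
  std::vector<int64_t> idx(rank, 0);  // multi-index into the result
  for (int64_t linear = 0; linear < numElts; ++linear) {
    int64_t srcOffset = 0;
    for (int64_t i = 0; i < rank; ++i)
      srcOffset += idx[i] * strides[perm[i]];
    if (srcOffset != linear)
      return false;  // this element moved, so it is a real transpose
    // Advance idx in row-major order over the result shape.
    for (int64_t i = rank - 1; i >= 0; --i) {
      if (++idx[i] < shape[perm[i]])
        break;
      idx[i] = 0;
    }
  }
  return true;
}
```

With every unit dim pinned to index 0, an element's row-major position equals its index along the single non-unit dim, so every permutation of, say, `1x1x5` keeps the order intact, while `2x1x3` with `[2, 1, 0]` does not.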
https://github.com/llvm/llvm-project/pull/72493