[Mlir-commits] [mlir] [mlir][vector] Move transpose with unit-dim to shape_cast pattern (PR #72493)

Diego Caballero llvmlistbot at llvm.org
Thu Nov 16 11:44:34 PST 2023


================
@@ -5564,12 +5564,51 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose with non-scalable unit dims into a shape_cast.
+///
+/// Replace:
+///   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
+///                                 vector<1xnxelty>
+/// with:
+///   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
+///
+/// Source with leading unit dim (inverse) is also replaced. Unit dim must
+/// be fixed. Non-unit dims can be scalable.
+class FoldTransposeWithNonScalableUnitDimsToShapeCast final
+    : public OpRewritePattern<TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp transpOp,
+                                PatternRewriter &rewriter) const override {
+    Value input = transpOp.getVector();
+    VectorType resType = transpOp.getResultVectorType();
+
+    SmallVector<int64_t> permutation;
+    transpOp.getTransp(permutation);
+
+    if (resType.getRank() == 2 &&
+        ((resType.getShape().front() == 1 &&
+          !resType.getScalableDims().front()) ||
+         (resType.getShape().back() == 1 &&
+          !resType.getScalableDims().back())) &&
+        permutation == ArrayRef<int64_t>({1, 0})) {
----------------
dcaballe wrote:

As a follow-up patch, I wonder if we could generalize this to n-D vectors where at most one dimension is != 1? If I'm not missing something, the permutation itself shouldn't even matter in those cases.
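
A rough sketch of the condition such a generalization might check, written as plain C++ with no MLIR dependency. The helper name `hasAtMostOneNonUnitDim` and the use of `std::vector` in place of `VectorType` are assumptions made for illustration; this is not code from the PR. The reasoning is that when at most one dimension is a non-unit dimension, every permutation leaves the linear order of the elements unchanged, so the transpose only changes the shape. Scalable dims are counted as non-unit here, mirroring the patch's requirement that the unit dim be fixed.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of MLIR): returns true when a vector with the
// given shape has at most one dimension that is not a fixed-size unit dim.
// A scalable dim such as [1] is treated as non-unit, because its runtime size
// is a multiple of vscale and need not be 1.
bool hasAtMostOneNonUnitDim(const std::vector<int64_t> &shape,
                            const std::vector<bool> &scalableDims) {
  assert(shape.size() == scalableDims.size() && "one scalable flag per dim");
  int nonUnitDims = 0;
  for (std::size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] != 1 || scalableDims[i])
      ++nonUnitDims;
  }
  // With zero or one non-unit dim, any permutation keeps the element order,
  // so the transpose could be expressed as a shape_cast.
  return nonUnitDims <= 1;
}
```

Under these assumptions, a shape like 1x4x1 with all dims fixed would qualify for any permutation, while 2x4, or 4x[1] with a scalable unit dim, would not.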

https://github.com/llvm/llvm-project/pull/72493

