[Mlir-commits] [mlir] [mlir][Vector] Add fold transpose(shape_cast) -> shape_cast (PR #73951)

Quinn Dawkins llvmlistbot at llvm.org
Thu Nov 30 07:28:11 PST 2023


================
@@ -5548,12 +5548,55 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose(shape_cast) into a new shape_cast, when the transpose just
+/// permutes a unit dim from the result of the shape_cast.
+class FoldTransposeShapeCast : public OpRewritePattern<TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp transpOp,
+                                PatternRewriter &rewriter) const override {
+    Value transposeSrc = transpOp.getVector();
+    auto shapeCastOp = transposeSrc.getDefiningOp<vector::ShapeCastOp>();
+    if (!shapeCastOp)
+      return failure();
+
+    auto sourceType = transpOp.getSourceVectorType();
----------------
qedawkins wrote:

If we had
```
%0 = shape_cast ... : vector<4x1x4> to vector<2x2x1x4>
transpose %0 [0, 1, 3, 2] : vector<2x2x1x4> to vector<2x2x4x1>
```
This reduces to the same discussion as in the other PR. Since the source is already a `shape_cast`, I won't block here, but would it still work to fold directly from the source vector type of the `shape_cast` (here `vector<4x1x4>`)?
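The legality question can be illustrated with a minimal C++ sketch. It assumes the fold is valid only when the transpose moves nothing but unit (size-1) dims, meaning the non-unit dims keep their relative order. In that case the row-major element order is unchanged, so the transpose acts like a `shape_cast` to the transposed shape. `isUnitDimOnlyPermutation`, `transposedShape`, and `numElements` are hypothetical helpers, not MLIR APIs, and they operate on plain `std::vector<int64_t>` shapes rather than `VectorType`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not an MLIR API): true when transposing a vector of
// shape `shape` by `perm` only moves unit dims. Follows vector.transpose
// semantics: result dim i is source dim perm[i].
bool isUnitDimOnlyPermutation(const std::vector<int64_t> &shape,
                              const std::vector<int64_t> &perm) {
  assert(shape.size() == perm.size() && "perm must match the rank");
  int64_t prevNonUnit = -1;
  for (int64_t srcDim : perm) {
    if (shape[static_cast<std::size_t>(srcDim)] == 1)
      continue; // Unit dims always have index 0, so they can move freely.
    if (srcDim < prevNonUnit)
      return false; // A non-unit dim was reordered: elements would move.
    prevNonUnit = srcDim;
  }
  return true;
}

// Hypothetical helper: result shape of the transpose, result[i] = shape[perm[i]].
std::vector<int64_t> transposedShape(const std::vector<int64_t> &shape,
                                     const std::vector<int64_t> &perm) {
  std::vector<int64_t> result;
  result.reserve(perm.size());
  for (int64_t srcDim : perm)
    result.push_back(shape[static_cast<std::size_t>(srcDim)]);
  return result;
}

// Hypothetical helper: total element count; a shape_cast must preserve it.
int64_t numElements(const std::vector<int64_t> &shape) {
  int64_t n = 1;
  for (int64_t d : shape)
    n *= d;
  return n;
}
```

For the example above, `isUnitDimOnlyPermutation({2, 2, 1, 4}, {0, 1, 3, 2})` holds and `transposedShape` yields `{2, 2, 4, 1}`, which has the same 16 elements as `vector<4x1x4>`. Under this assumption, the pair could become a single `shape_cast` from `vector<4x1x4>` to `vector<2x2x4x1>`, which is one reading of the suggestion to use the `shape_cast` source type.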

https://github.com/llvm/llvm-project/pull/73951

