[Mlir-commits] [mlir] [mlir][Vector] Add fold transpose(shape_cast) -> shape_cast (PR #73951)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Nov 30 07:55:27 PST 2023
================
@@ -5548,12 +5548,55 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};
+/// Folds transpose(shape_cast) into a new shape_cast, when the transpose just
+/// permutes a unit dim from the result of the shape_cast.
+class FoldTransposeShapeCast : public OpRewritePattern<TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp transpOp,
+                                PatternRewriter &rewriter) const override {
+    Value transposeSrc = transpOp.getVector();
+    auto shapeCastOp = transposeSrc.getDefiningOp<vector::ShapeCastOp>();
+    if (!shapeCastOp)
+      return failure();
+
+    auto sourceType = transpOp.getSourceVectorType();
----------------
MacDue wrote:
It's still a legal `shape_cast` and will lower to (pretty much) the same thing. But yeah, the point here is that we're not adding a `shape_cast` where there wasn't already one, so this should not cause problems for SPIR-V :)
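For reference, the condition in the diff's doc comment (the transpose "just permutes a unit dim") is what makes this fold safe. A transpose leaves the row-major element order unchanged exactly when the non-unit dims keep their relative order. In that case it is equivalent to a reshape, so transpose(shape_cast(x)) can become a single shape_cast of `x`. Below is a minimal standalone C++ sketch of that legality check. It assumes MLIR's convention that result dim `i` of `vector.transpose` is source dim `perm[i]`. The helpers `isPureReshapeTranspose` and `transposedShape` are hypothetical names that operate on plain `std::vector` shapes. They are not MLIR APIs and not the PR's actual implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not an MLIR API): returns true if transposing a vector
// of shape `shape` by `perm` keeps every element in the same row-major
// (linearized) position. This holds exactly when the non-unit dims keep their
// relative order, i.e. the permutation only moves size-1 dims around, so the
// transpose could be expressed as a shape_cast instead.
bool isPureReshapeTranspose(const std::vector<int64_t> &shape,
                            const std::vector<int64_t> &perm) {
  assert(shape.size() == perm.size() && "perm must cover every dim");
  int64_t lastNonUnit = -1;
  for (int64_t srcDim : perm) {
    if (shape[srcDim] == 1)
      continue; // A unit dim always has index 0, so moving it is harmless.
    if (srcDim < lastNonUnit)
      return false; // Two non-unit dims swapped: elements get reordered.
    lastNonUnit = srcDim;
  }
  return true;
}

// Hypothetical helper: result shape of the transpose, result[i] = shape[perm[i]].
// When the check above passes, this is the target type of the folded shape_cast.
std::vector<int64_t> transposedShape(const std::vector<int64_t> &shape,
                                     const std::vector<int64_t> &perm) {
  std::vector<int64_t> result;
  result.reserve(perm.size());
  for (int64_t srcDim : perm)
    result.push_back(shape[srcDim]);
  return result;
}
```

For example, transposing a `vector<4x1xf32>` with permutation `[1, 0]` into a `vector<1x4xf32>` only moves the unit dim, so the check passes and the result is just a reshape. Transposing a `vector<2x3xf32>` with `[1, 0]` genuinely reorders elements and cannot be replaced by a `shape_cast`.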
https://github.com/llvm/llvm-project/pull/73951