[Mlir-commits] [mlir] [mlir][tensor] Fold rank increasing expand_shape into insert_slice (PR #93018)
Matthias Springer
llvmlistbot at llvm.org
Thu May 23 05:01:10 PDT 2024
================
@@ -79,12 +79,42 @@ struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
return success();
}
};
+
+/// Fold rank increasing expand_shape into insert_slice.
+template <typename OpTy>
+struct FoldRankIncreasingExpandIntoInsert : public OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy insertSliceOp,
+ PatternRewriter &rewriter) const override {
+ auto expandShapeOp = insertSliceOp.getSource()
+ .template getDefiningOp<tensor::ExpandShapeOp>();
+ if (!expandShapeOp)
+ return failure();
+
+ // Only fold away simple rank increasing expansion.
+ SliceVerificationResult res = isRankReducedType(
+ expandShapeOp.getResultType(), expandShapeOp.getSrcType());
+ if (res != SliceVerificationResult::Success) {
----------------
matthias-springer wrote:
nit: the braces around this trivial `if` body are not needed.
https://github.com/llvm/llvm-project/pull/93018
More information about the Mlir-commits mailing list