[Mlir-commits] [mlir] [mlir][tensor] Fold rank increasing expand_shape into insert_slice (PR #93018)

Matthias Springer llvmlistbot at llvm.org
Thu May 23 05:01:12 PDT 2024


================
@@ -79,12 +79,42 @@ struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
     return success();
   }
 };
+
+/// Fold rank increasing expand_shape into insert_slice.
+template <typename OpTy>
+struct FoldRankIncreasingExpandIntoInsert : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy insertSliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp = insertSliceOp.getSource()
+                             .template getDefiningOp<tensor::ExpandShapeOp>();
+    if (!expandShapeOp)
+      return failure();
+
+    // Only fold away simple rank increasing expansion.
+    SliceVerificationResult res = isRankReducedType(
+        expandShapeOp.getResultType(), expandShapeOp.getSrcType());
+    if (res != SliceVerificationResult::Success) {
+      return rewriter.notifyMatchFailure(insertSliceOp,
+                                         "expected rank increasing expansion");
+    }
+
+    rewriter.modifyOpInPlace(insertSliceOp, [&]() {
+      insertSliceOp.setOperand(/*source=*/0, expandShapeOp.getSrc());
----------------
matthias-springer wrote:

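Suggestion: instead of `insertSliceOp.setOperand(/*source=*/0, ...)`, use `insertSliceOp.getSourceMutable().assign(...)`, so the source operand is updated through its named accessor rather than a hard-coded operand index.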


https://github.com/llvm/llvm-project/pull/93018


More information about the Mlir-commits mailing list