[Mlir-commits] [mlir] [mlir][sparse] clone an empty sparse tensor when fusing convert into pro… (PR #92158)

Peiming Liu llvmlistbot at llvm.org
Tue May 14 11:24:36 PDT 2024


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/92158

…ducer.

>From 0ba059e06352d2cb2b68a592379df6cfe495316a Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 14 May 2024 18:23:57 +0000
Subject: [PATCH] [mlir][sparse] clone an empty sparse tensor when fusing convert
 into producer.

---
 .../Transforms/SparseTensorRewriting.cpp           | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index da635c2578885..5fb009e3eebe6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -302,17 +302,17 @@ struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
         !producer.getResult(0).hasOneUse()) {
       return failure();
     }
+    // Clone the materialization operation, but update the result to sparse.
+    rewriter.setInsertionPoint(producer);
+    Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
+    Operation *cloned = rewriter.clone(*init);
+    cloned->getResult(0).setType(op.getResult().getType());
+
     rewriter.modifyOpInPlace(producer, [&]() {
+      producer.getDpsInitsMutable().assign(cloned->getResults());
       producer.getResult(0).setType(op.getResult().getType());
     });
 
-    Operation *materializeOp =
-        producer.getDpsInitOperand(0)->get().getDefiningOp();
-
-    rewriter.modifyOpInPlace(materializeOp, [&]() {
-      materializeOp->getResult(0).setType(op.getResult().getType());
-    });
-
     rewriter.replaceAllOpUsesWith(op, producer);
     op->erase();
 



More information about the Mlir-commits mailing list