[Mlir-commits] [mlir] fb8f492 - [mlir][sparse] clone an empty sparse tensor when fusing convert into pro… (#92158)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 14 13:26:55 PDT 2024


Author: Peiming Liu
Date: 2024-05-14T13:26:49-07:00
New Revision: fb8f492a1ccb2236a82701c76f82960fd6cdb725

URL: https://github.com/llvm/llvm-project/commit/fb8f492a1ccb2236a82701c76f82960fd6cdb725
DIFF: https://github.com/llvm/llvm-project/commit/fb8f492a1ccb2236a82701c76f82960fd6cdb725.diff

LOG: [mlir][sparse] clone an empty sparse tensor when fusing convert into producer. (#92158)

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index da635c2578885..5fb009e3eebe6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -302,17 +302,17 @@ struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
         !producer.getResult(0).hasOneUse()) {
       return failure();
     }
+    // Clone the materialization operation, but update the result to sparse.
+    rewriter.setInsertionPoint(producer);
+    Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
+    Operation *cloned = rewriter.clone(*init);
+    cloned->getResult(0).setType(op.getResult().getType());
+
     rewriter.modifyOpInPlace(producer, [&]() {
+      producer.getDpsInitsMutable().assign(cloned->getResults());
       producer.getResult(0).setType(op.getResult().getType());
     });
 
-    Operation *materializeOp =
-        producer.getDpsInitOperand(0)->get().getDefiningOp();
-
-    rewriter.modifyOpInPlace(materializeOp, [&]() {
-      materializeOp->getResult(0).setType(op.getResult().getType());
-    });
-
     rewriter.replaceAllOpUsesWith(op, producer);
     op->erase();
 

diff  --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
index efa92e565ba57..4e4d2c27b0966 100644
--- a/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
@@ -54,6 +54,50 @@ func.func @fold_convert(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x
   return %2 : tensor<128x32x32x1xf32, #CCCD>
 }
 
+#trait_bin = {
+  indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+  ],
+  iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+}
+
+// CHECK-FOLD-LABEL:   func.func @fold_convert_multi_use(
+// CHECK-FOLD:           tensor.empty() : tensor<128x32x32x1xf32>
+// CHECK-FOLD:           linalg.generic
+// CHECK-FOLD:           tensor.empty() : tensor<128x32x32x1xf32, #sparse>
+// CHECK-FOLD:           linalg.generic
+// CHECK-FOLD-NOT:       sparse_tensor.convert
+func.func @fold_convert_multi_use(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>,
+                        %arg2: tensor<128x32x32x1xf32>, %arg3: tensor<128x32x32x1xf32>) -> (tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32, #CCCD>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_0 = arith.constant 1.000000e+00 : f32
+  %cst_1 = arith.constant 1.000000e+00 : f32
+
+  %0 = tensor.empty() : tensor<128x32x32x1xf32>
+  %1 = linalg.generic #trait_bin
+  ins(%arg0, %arg1 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
+  outs(%0 : tensor<128x32x32x1xf32>) {
+    ^bb0(%in: f32, %in_1: f32, %out: f32):
+      %3 = arith.mulf %in, %in_1 : f32
+      linalg.yield %3 : f32
+    } -> tensor<128x32x32x1xf32>
+
+  // A second kernel that uses %0 as the init operand.
+  %3 = linalg.generic #trait_bin
+  ins(%arg2, %arg3 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
+  outs(%0 : tensor<128x32x32x1xf32>) {
+    ^bb0(%in: f32, %in_1: f32, %out: f32):
+      %3 = arith.mulf %in, %in_1 : f32
+      linalg.yield %3 : f32
+    } -> tensor<128x32x32x1xf32>
+  %4 = sparse_tensor.convert %3 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #CCCD>
+
+  return %1, %4 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32, #CCCD>
+}
+
+
 
 // FIXME: The following kernel is not sparsifiable because `arith.select`
 // operations is not handled by the sparse compiler at the moment.


        


More information about the Mlir-commits mailing list