[Mlir-commits] [mlir] [mlir][sparse] clone an empty sparse tensor when fusing convert into pro… (PR #92158)
Peiming Liu
llvmlistbot at llvm.org
Tue May 14 11:55:41 PDT 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/92158
>From 0ba059e06352d2cb2b68a592379df6cfe495316a Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 14 May 2024 18:23:57 +0000
Subject: [PATCH 1/2] [mlir][sparse] clone an empty sparse tensor when fusing
convert into producer.
---
.../Transforms/SparseTensorRewriting.cpp | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index da635c2578885..5fb009e3eebe6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -302,17 +302,17 @@ struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
!producer.getResult(0).hasOneUse()) {
return failure();
}
+ // Clone the materialization operation, but update the result to sparse.
+ rewriter.setInsertionPoint(producer);
+ Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
+ Operation *cloned = rewriter.clone(*init);
+ cloned->getResult(0).setType(op.getResult().getType());
+
rewriter.modifyOpInPlace(producer, [&]() {
+ producer.getDpsInitsMutable().assign(cloned->getResults());
producer.getResult(0).setType(op.getResult().getType());
});
- Operation *materializeOp =
- producer.getDpsInitOperand(0)->get().getDefiningOp();
-
- rewriter.modifyOpInPlace(materializeOp, [&]() {
- materializeOp->getResult(0).setType(op.getResult().getType());
- });
-
rewriter.replaceAllOpUsesWith(op, producer);
op->erase();
>From 73ac45f4019d1ea5b7bb64f3a22dd8c73d144374 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 14 May 2024 18:55:28 +0000
Subject: [PATCH 2/2] add test
---
.../fuse_sparse_convert_into_producer.mlir | 44 +++++++++++++++++++
1 file changed, 44 insertions(+)
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
index efa92e565ba57..4e4d2c27b0966 100644
--- a/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
@@ -54,6 +54,50 @@ func.func @fold_convert(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x
return %2 : tensor<128x32x32x1xf32, #CCCD>
}
+#trait_bin = {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+}
+
+// CHECK-FOLD-LABEL: func.func @fold_convert_multi_use(
+// CHECK-FOLD: tensor.empty() : tensor<128x32x32x1xf32>
+// CHECK-FOLD: linalg.generic
+// CHECK-FOLD: tensor.empty() : tensor<128x32x32x1xf32, #sparse>
+// CHECK-FOLD: linalg.generic
+// CHECK-FOLD-NOT: sparse_tensor.convert
+func.func @fold_convert_multi_use(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>,
+ %arg2: tensor<128x32x32x1xf32>, %arg3: tensor<128x32x32x1xf32>) -> (tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32, #CCCD>) {
+ %cst = arith.constant 0.000000e+00 : f32
+ %cst_0 = arith.constant 1.000000e+00 : f32
+ %cst_1 = arith.constant 1.000000e+00 : f32
+
+ %0 = tensor.empty() : tensor<128x32x32x1xf32>
+ %1 = linalg.generic #trait_bin
+ ins(%arg0, %arg1 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
+ outs(%0 : tensor<128x32x32x1xf32>) {
+ ^bb0(%in: f32, %in_1: f32, %out: f32):
+ %3 = arith.mulf %in, %in_1 : f32
+ linalg.yield %3 : f32
+ } -> tensor<128x32x32x1xf32>
+
+ // A second kernel that uses %0 as the init operand.
+ %3 = linalg.generic #trait_bin
+ ins(%arg2, %arg3 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
+ outs(%0 : tensor<128x32x32x1xf32>) {
+ ^bb0(%in: f32, %in_1: f32, %out: f32):
+ %3 = arith.mulf %in, %in_1 : f32
+ linalg.yield %3 : f32
+ } -> tensor<128x32x32x1xf32>
+ %4 = sparse_tensor.convert %3 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #CCCD>
+
+ return %1, %4 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32, #CCCD>
+}
+
+
// FIXME: The following kernel is not sparsifiable because `arith.select`
// operations are not handled by the sparse compiler at the moment.
More information about the Mlir-commits mailing list