[Mlir-commits] [mlir] [mlir][sparse] fold sparse convert into producer linalg op. (PR #89999)
Peiming Liu
llvmlistbot at llvm.org
Fri Apr 26 09:35:41 PDT 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/89999
From a905cb262686a347549d3eb811359046812232e6 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 24 Apr 2024 22:10:27 +0000
Subject: [PATCH 1/2] [mlir][sparse] fold sparse convert into producer generic
operation.
---
.../Dialect/SparseTensor/IR/SparseTensor.h | 15 +++---
.../Transforms/SparseTensorRewriting.cpp | 38 +++++++++++++--
.../Transforms/Sparsification.cpp | 44 +++++++++++------
.../fuse_sparse_convert_into_producer.mlir | 48 +++++++++++++++++++
.../SparseTensor/no_fold_into_consumer.mlir | 2 -
5 files changed, 121 insertions(+), 26 deletions(-)
create mode 100644 mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 5e523ec428aefb..550e28813b4e9b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -90,17 +90,20 @@ inline MemRefType getMemRefType(T &&t) {
SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
/// Returns true iff MLIR operand has any sparse operand.
-inline bool hasAnySparseOperand(Operation *op) {
- return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
- return getSparseTensorEncoding(t) != nullptr;
+inline bool hasAnySparseType(TypeRange types) {
+ return llvm::any_of(types, [](Type type) {
+ return getSparseTensorEncoding(type) != nullptr;
});
}
+/// Returns true iff MLIR operation has any sparse operand.
+inline bool hasAnySparseOperand(Operation *op) {
+ return hasAnySparseType(op->getOperands().getTypes());
+}
+
/// Returns true iff MLIR operand has any sparse result.
inline bool hasAnySparseResult(Operation *op) {
- return llvm::any_of(op->getResults().getTypes(), [](Type t) {
- return getSparseTensorEncoding(t) != nullptr;
- });
+ return hasAnySparseType(op->getResults().getTypes());
}
/// Returns true iff MLIR operand has any sparse operand or result.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 5a39dfc6207707..641dcc61d7d09c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -289,6 +289,37 @@ struct FuseExtractSliceWithConcat
}
};
+/// Rewriting rule that converts direct yield of zero with initial allocation.
+struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ConvertOp op,
+ PatternRewriter &rewriter) const override {
+ auto producer = op.getSource().getDefiningOp<GenericOp>();
+ if (!producer || producer.getDpsInits().size() != 1 ||
+ !isMaterializing(producer.getDpsInitOperand(0), false) ||
+ !producer.getResult(0).hasOneUse()) {
+ return failure();
+ }
+ rewriter.modifyOpInPlace(producer, [&]() {
+ producer.getResult(0).setType(op.getResult().getType());
+ });
+
+ Operation *materializeOp =
+ producer.getDpsInitOperand(0)->get().getDefiningOp();
+
+ rewriter.modifyOpInPlace(materializeOp, [&]() {
+ materializeOp->getResult(0).setType(op.getResult().getType());
+ });
+
+ rewriter.replaceAllOpUsesWith(op, producer);
+ op->erase();
+
+ return success();
+ }
+};
+
/// Rewriting rule that converts direct yield of zero with initial allocation.
struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
public:
@@ -1506,9 +1537,10 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
//===---------------------------------------------------------------------===//
void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
- patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
- FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
- GenSemiRingSelect, PrintRewriter>(patterns.getContext());
+ patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
+ FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
+ GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
+ patterns.getContext());
}
void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
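To make the new rewrite concrete, here is a minimal before/after sketch of what FoldConvertIntoProducer does, using an illustrative 1-D identity kernel (the trait, encoding, and names below are made up for the example and are not taken from this patch):

#trait1d = {
  indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
  iterator_types = ["parallel"]
}
#sparse1d = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>

// Before: the generic op materializes a dense tensor that a separate
// sparse_tensor.convert then turns sparse.
func.func @before(%a: tensor<8xf32>) -> tensor<8xf32, #sparse1d> {
  %e = tensor.empty() : tensor<8xf32>
  %d = linalg.generic #trait1d
       ins(%a : tensor<8xf32>) outs(%e : tensor<8xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<8xf32>
  %s = sparse_tensor.convert %d : tensor<8xf32> to tensor<8xf32, #sparse1d>
  return %s : tensor<8xf32, #sparse1d>
}

// After: the convert is folded away; the result type of the generic op and
// of its materializing tensor.empty are rewritten in place to the sparse type.
func.func @after(%a: tensor<8xf32>) -> tensor<8xf32, #sparse1d> {
  %e = tensor.empty() : tensor<8xf32, #sparse1d>
  %s = linalg.generic #trait1d
       ins(%a : tensor<8xf32>) outs(%e : tensor<8xf32, #sparse1d>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<8xf32, #sparse1d>
  return %s : tensor<8xf32, #sparse1d>
}

The pattern only fires when the generic op has a single init operand that comes from a materializing allocation and the converted result has exactly one use, so the in-place type rewrites cannot be observed elsewhere.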
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index cd046b670d9a8e..0a9bb40b458d68 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -403,6 +403,22 @@ static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
}
+static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
+ Value sparseOut, ValueRange ivs, Value v) {
+ scf::IfOp condInsert =
+ builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true);
+ // True branch.
+ builder.setInsertionPointToStart(condInsert.thenBlock());
+ Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
+ builder.create<scf::YieldOp>(loc, res);
+ // False branch.
+ builder.setInsertionPointToStart(condInsert.elseBlock());
+ builder.create<scf::YieldOp>(loc, sparseOut);
+ // Value assignment.
+ builder.setInsertionPointAfter(condInsert);
+ return condInsert.getResult(0);
+}
+
/// Generates insertion code to implement dynamic tensor store.
static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
Value rhs) {
@@ -423,23 +439,21 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
// return updated chain
// else
// return unmodified chain
- scf::IfOp ifValidLexInsert = builder.create<scf::IfOp>(
- loc, chain.getType(), env.getValidLexInsert(),
- /*else=*/true);
- // True branch.
- builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
- Value res = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
- builder.create<scf::YieldOp>(loc, res);
- // False branch.
- builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
- builder.create<scf::YieldOp>(loc, chain);
- // Value assignment.
- builder.setInsertionPointAfter(ifValidLexInsert);
- env.updateInsertionChain(ifValidLexInsert.getResult(0));
+ Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
+ chain, ivs, rhs);
+ env.updateInsertionChain(out);
} else {
+ Value sparseOut;
+ if (!hasAnySparseType(env.op().getInputs().getTypes())) {
+      // This is an all-dense -> sparse kernel; test rhs != 0 before
+ // insertion.
+ Value nz = genIsNonzero(builder, loc, rhs);
+ sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
+ } else {
+ sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
+ }
// Generates regular insertion chain.
- env.updateInsertionChain(
- builder.create<tensor::InsertOp>(loc, rhs, chain, ivs));
+ env.updateInsertionChain(sparseOut);
}
return;
}
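For reference, a hand-written sketch (not actual compiler output) of the guarded insertion that genConditionalInsert builds for an all-dense-input, sparse-output kernel; the une comparison mirrors what genIsNonzero would produce for f32 values, and all names and types here are illustrative:

#sparsevec = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>

func.func @guarded_insert(%chain: tensor<8xf32, #sparsevec>, %v: f32,
                          %i: index) -> tensor<8xf32, #sparsevec> {
  %zero = arith.constant 0.0 : f32
  // Only extend the insertion chain when the computed value is nonzero.
  %nz = arith.cmpf une, %v, %zero : f32
  %chain2 = scf.if %nz -> (tensor<8xf32, #sparsevec>) {
    %ins = tensor.insert %v into %chain[%i] : tensor<8xf32, #sparsevec>
    scf.yield %ins : tensor<8xf32, #sparsevec>
  } else {
    scf.yield %chain : tensor<8xf32, #sparsevec>
  }
  return %chain2 : tensor<8xf32, #sparsevec>
}

Skipping zeros this way keeps an all-dense -> sparse kernel from densifying the output; when at least one input is sparse, the generated loops already visit only stored entries, so the unconditional tensor.insert path is kept.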
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
new file mode 100644
index 00000000000000..077dde230fd156
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification | FileCheck %s
+
+#trait = {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+}
+
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+
+// CHECK-LABEL: func.func @test(
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.if
+// CHECK-NEXT: tensor.insert
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: else
+// CHECK-NEXT: scf.yield
+// CHECK: scf.yield
+// CHECK: scf.yield
+// CHECK: scf.yield
+// CHECK: sparse_tensor.load
+func.func @test(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #sparse> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %cst_0 = arith.constant 1.000000e+00 : f32
+ %cst_1 = arith.constant 1.000000e+00 : f32
+ %0 = tensor.empty() : tensor<128x32x32x1xf32>
+ %1 = linalg.generic #trait
+ ins(%arg0, %arg1, %arg2 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
+ outs(%0 : tensor<128x32x32x1xf32>) {
+ ^bb0(%in: f32, %in_2: f32, %in_3: f32, %out: f32):
+ %3 = arith.subf %cst_0, %in_2 : f32
+ %4 = arith.mulf %in, %3 : f32
+ %5 = arith.mulf %4, %cst_1 : f32
+ %6 = arith.addf %5, %in_3 : f32
+ %7 = arith.subf %6, %cst_0 : f32
+ %8 = arith.cmpf uge, %7, %cst : f32
+ %9 = arith.uitofp %8 : i1 to f32
+ linalg.yield %9 : f32
+ } -> tensor<128x32x32x1xf32>
+ %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #sparse>
+ return %2 : tensor<128x32x32x1xf32, #sparse>
+}
diff --git a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
index bbc7f397e793fe..f2f64567d5bd01 100644
--- a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
+++ b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
@@ -24,7 +24,6 @@ module {
// CHECK: arith.constant
// CHECK: tensor.empty()
// CHECK: linalg.generic
- // CHECK: sparse_tensor.convert
// CHECK: return
//
func.func @avoid_fold(%0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
@@ -44,4 +43,3 @@ module {
return %cast : tensor<10x20x30xf64, #sparse>
}
}
-
From c689fbbce4b6033f1ee6854fdfb3dd411303b16e Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 24 Apr 2024 23:25:10 +0000
Subject: [PATCH 2/2] address comments
---
.../Dialect/SparseTensor/IR/SparseTensor.h | 2 +-
.../Transforms/SparseTensorRewriting.cpp | 2 +-
.../fuse_sparse_convert_into_producer.mlir | 40 ++++++++++++++---
.../SparseTensor/no_fold_into_consumer.mlir | 45 -------------------
4 files changed, 37 insertions(+), 52 deletions(-)
delete mode 100644 mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 550e28813b4e9b..b182b4c72b9535 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -89,7 +89,7 @@ inline MemRefType getMemRefType(T &&t) {
/// Returns null-attribute for any type without an encoding.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
-/// Returns true iff MLIR operand has any sparse operand.
+/// Returns true iff the type range has any sparse tensor type.
inline bool hasAnySparseType(TypeRange types) {
return llvm::any_of(types, [](Type type) {
return getSparseTensorEncoding(type) != nullptr;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 641dcc61d7d09c..9a8c6422a7ff62 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -289,7 +289,7 @@ struct FuseExtractSliceWithConcat
}
};
-/// Rewriting rule that converts direct yield of zero with initial allocation.
+/// Rewriting rule that fuses sparse_tensor.convert into producer.
struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
public:
using OpRewritePattern::OpRewritePattern;
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
index 077dde230fd156..efa92e565ba575 100644
--- a/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
@@ -1,3 +1,4 @@
+// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map | FileCheck %s --check-prefix=CHECK-FOLD
// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification | FileCheck %s
#trait = {
@@ -10,9 +11,12 @@
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
}
-#sparse = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: func.func @test(
+#COO = #sparse_tensor.encoding<{map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa))}>
+#CCCD = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+
+// CHECK-LABEL: func.func @fold_convert(
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for
@@ -25,7 +29,10 @@
// CHECK: scf.yield
// CHECK: scf.yield
// CHECK: sparse_tensor.load
-func.func @test(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #sparse> {
+
+// CHECK-FOLD-LABEL: func.func @fold_convert(
+// CHECK-FOLD-NOT: sparse_tensor.convert
+func.func @fold_convert(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #CCCD> {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 1.000000e+00 : f32
%cst_1 = arith.constant 1.000000e+00 : f32
@@ -43,6 +50,29 @@ func.func @test(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>,
%9 = arith.uitofp %8 : i1 to f32
linalg.yield %9 : f32
} -> tensor<128x32x32x1xf32>
- %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #sparse>
- return %2 : tensor<128x32x32x1xf32, #sparse>
+ %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #CCCD>
+ return %2 : tensor<128x32x32x1xf32, #CCCD>
+}
+
+
+// FIXME: The following kernel is not sparsifiable because `arith.select`
+// operations are not handled by the sparse compiler at the moment.
+//
+// CHECK-FOLD-LABEL: func.func @fold_cast(
+// CHECK-FOLD-NOT: sparse_tensor.convert
+func.func @fold_cast(%0: tensor<10x20x30xf64, #COO>) -> tensor<10x20x30xf64, #COO> {
+ %cst = arith.constant 0.000000e+00 : f64
+ %1 = tensor.empty() : tensor<10x20x30xf64>
+ %2 = linalg.generic { indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel", "parallel"]
+ }
+ ins (%0 : tensor<10x20x30xf64, #COO>)
+ outs(%1 : tensor<10x20x30xf64>) {
+ ^bb0(%in: f64, %out: f64):
+ %4 = arith.cmpf ugt, %in, %cst : f64
+ %5 = arith.select %4, %in, %cst : f64
+ linalg.yield %5 : f64
+ } -> tensor<10x20x30xf64>
+ %cast = tensor.cast %2 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #COO>
+ return %cast : tensor<10x20x30xf64, #COO>
}
diff --git a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
deleted file mode 100644
index f2f64567d5bd01..00000000000000
--- a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
+++ /dev/null
@@ -1,45 +0,0 @@
-// RUN: mlir-opt %s --canonicalize --pre-sparsification-rewrite | FileCheck %s
-
-#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
-#sparse = #sparse_tensor.encoding<{
- map = (d0, d1, d2) ->
- (d0 : compressed(nonunique),
- d1 : singleton(nonunique, soa),
- d2 : singleton(soa)),
- posWidth = 64,
- crdWidth = 64
-}>
-
-
-module {
- //
- // This IR should not end up in an infinite loop trying to fold
- // the linalg producer into the tensor cast consumer (even though
- // static sizes can fold, the different encodings cannot). The
- // cast was sloppy to begin with (but it has been observed by
- // external sources) and can be easily repaired by the sparsifier.
- //
- // CHECK-LABEL: func @avoid_fold
- // CHECK: arith.constant
- // CHECK: tensor.empty()
- // CHECK: linalg.generic
- // CHECK: return
- //
- func.func @avoid_fold(%0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
- %1 = tensor.empty() : tensor<10x20x30xf64>
- %2 = linalg.generic { indexing_maps = [#map, #map],
- iterator_types = ["parallel", "parallel", "parallel"]
- }
- ins (%0 : tensor<10x20x30xf64, #sparse>)
- outs(%1 : tensor<10x20x30xf64>) {
- ^bb0(%in: f64, %out: f64):
- %cst = arith.constant 0.000000e+00 : f64
- %4 = arith.cmpf ugt, %in, %cst : f64
- %5 = arith.select %4, %in, %cst : f64
- linalg.yield %5 : f64
- } -> tensor<10x20x30xf64>
- %cast = tensor.cast %2 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #sparse>
- return %cast : tensor<10x20x30xf64, #sparse>
- }
-}