[Mlir-commits] [mlir] [mlir][sparse] make sparse compiler more admissible. (PR #90927)
Peiming Liu
llvmlistbot at llvm.org
Thu May 2 18:43:59 PDT 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/90927
From 50ea686d579db232543d3efcaffd8c4cf9f5be8f Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 3 May 2024 01:37:54 +0000
Subject: [PATCH] [mlir][sparse] make sparse compiler more admissible.
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 6 --
.../lib/Dialect/SparseTensor/Utils/Merger.cpp | 54 ++++++++-------
.../Dialect/SparseTensor/sparse_fusion.mlir | 69 ++++++++++++++-----
3 files changed, 79 insertions(+), 50 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 89fb4944c0ca3c..ad313c2d5ce603 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -432,12 +432,6 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
Operation *producer = opOperand.get().getDefiningOp();
- // Do not fuse a sparse-in/dense-out operation, as the
- // result is too often not sparsifiable anymore.
- if (sparse_tensor::hasAnySparseOperand(producer) &&
- !sparse_tensor::hasAnySparseResult(producer))
- return failure();
-
// Find the producer of the operand.
FailureOr<ElementwiseOpFusionResult> fusionResult =
fuseElementwiseOps(rewriter, &opOperand);
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 9c0aed3c18eff2..308fbd965259db 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1356,50 +1356,54 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
// See buildLattices() for an explanation of rejecting certain
// division and shift operations.
if (def->getNumOperands() == 2) {
- const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
- const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
- bool hasSpDep = xDepSp || yDepSp;
+ const auto [x, xSpVals] = buildTensorExp(op, def->getOperand(0));
+ const auto [y, ySpVals] = buildTensorExp(op, def->getOperand(1));
+ // A conjunctive operation yields a "sparse" result if any operand is
+ // sparse; a disjunctive operation yields a "sparse" result only if all
+ // operands are sparse.
+ bool conjSpVals = xSpVals || ySpVals;
+ bool disjSpVals = xSpVals && ySpVals;
if (x.has_value() && y.has_value()) {
const ExprId e0 = *x;
const ExprId e1 = *y;
if (isa<arith::MulFOp>(def))
- return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kMulF, e0, e1), conjSpVals};
if (isa<complex::MulOp>(def))
- return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kMulC, e0, e1), conjSpVals};
if (isa<arith::MulIOp>(def))
- return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kMulI, e0, e1), conjSpVals};
if (isa<arith::DivFOp>(def) && !maybeZero(e1))
- return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kDivF, e0, e1), conjSpVals};
if (isa<complex::DivOp>(def) && !maybeZero(e1))
- return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kDivC, e0, e1), conjSpVals};
if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
- return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kDivS, e0, e1), conjSpVals};
if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
- return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kDivU, e0, e1), conjSpVals};
if (isa<arith::AddFOp>(def))
- return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kAddF, e0, e1), disjSpVals};
if (isa<complex::AddOp>(def))
- return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kAddC, e0, e1), disjSpVals};
if (isa<arith::AddIOp>(def))
- return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kAddI, e0, e1), disjSpVals};
if (isa<arith::SubFOp>(def))
- return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kSubF, e0, e1), disjSpVals};
if (isa<complex::SubOp>(def))
- return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kSubC, e0, e1), disjSpVals};
if (isa<arith::SubIOp>(def))
- return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kSubI, e0, e1), disjSpVals};
if (isa<arith::AndIOp>(def))
- return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kAndI, e0, e1), conjSpVals};
if (isa<arith::OrIOp>(def))
- return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kOrI, e0, e1), disjSpVals};
if (isa<arith::XOrIOp>(def))
- return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kXorI, e0, e1), disjSpVals};
if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
- return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kShrS, e0, e1), conjSpVals};
if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
- return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kShrU, e0, e1), conjSpVals};
if (isa<arith::ShLIOp>(def) && isInvariant(e1))
- return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kShlI, e0, e1), conjSpVals};
if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
if (ci.getPredicate() == arith::CmpIPredicate::eq &&
ci.getPredicate() == arith::CmpIPredicate::sle &&
@@ -1413,7 +1417,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
ci.getPredicateAttr());
- return {e, hasSpDep};
+ return {e, conjSpVals};
}
if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
if (cf.getPredicate() == arith::CmpFPredicate::OEQ &&
@@ -1431,7 +1435,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
}
auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
cf.getPredicateAttr());
- return {e, hasSpDep};
+ return {e, conjSpVals};
}
if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
@@ -1439,7 +1443,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
isAdmissibleBranch(binop, binop.getLeftRegion())) &&
(binop.getRightIdentity() ||
isAdmissibleBranch(binop, binop.getRightRegion())))
- return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep};
+ return {addExp(TensorExp::Kind::kBinary, e0, e1, def), conjSpVals};
}
}
}
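
To make the conjSpVals/disjSpVals distinction above concrete, here is a small
standalone C++ sketch (not part of the patch; the function names are purely
illustrative): a conjunctive, multiplication-like operation keeps the result
sparse when any operand is sparse, whereas a disjunctive, addition-like
operation keeps it sparse only when all operands are sparse.

    #include <cassert>

    // A conjunctive (multiplication-like) op produces zero whenever either
    // operand is zero, so its result stays sparse if any operand is sparse.
    static bool conjunctiveIsSparse(bool lhsSparse, bool rhsSparse) {
      return lhsSparse || rhsSparse;
    }

    // A disjunctive (addition-like) op produces a nonzero wherever either
    // operand is nonzero, so its result is sparse only if all operands are.
    static bool disjunctiveIsSparse(bool lhsSparse, bool rhsSparse) {
      return lhsSparse && rhsSparse;
    }

    int main() {
      // sparse * dense stays sparse; sparse + dense densifies the result.
      assert(conjunctiveIsSparse(true, false));
      assert(!disjunctiveIsSparse(true, false));
      return 0;
    }
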
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir b/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
index 8780baac199e16..2cc64434a1d8f2 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --linalg-fuse-elementwise-ops | FileCheck %s
+// RUN: mlir-opt %s --linalg-fuse-elementwise-ops --sparse-reinterpret-map --sparsification | FileCheck %s
#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
@@ -11,22 +11,59 @@
doc = "B(i) = OP A(i)"
}
-// CHECK-LABEL: func @sparse_fusion
-// CHECK: linalg.generic
-// CHECK: arith.addf
-// CHECK: linalg.generic
-// CHECK: math.exp
-// CHECK: arith.maximumf
-// CHECK-NOT: linalg.generic
-// CHECK: return
+
+// CHECK-LABEL: func.func @sparse_fusion(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<100xf64, #sparse>) -> tensor<100xf64> {
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant true
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 100 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1.000000e+00 : f64
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1.000000e+02 : f64
+// CHECK-DAG: %[[VAL_8:.*]] = tensor.empty() : tensor<100xf64>
+// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<100xf64, #sparse> to memref<?xindex>
+// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<100xf64, #sparse> to memref<?xindex>
+// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<100xf64, #sparse> to memref<?xf64>
+// CHECK-DAG: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_8]] : memref<100xf64>
+// CHECK: linalg.fill ins(%[[VAL_4]] : f64) outs(%[[VAL_12]] : memref<100xf64>)
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_13]], %[[VAL_17:.*]] = %[[VAL_3]]) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_18:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_14]] : index
+// CHECK: scf.condition(%[[VAL_18]]) %[[VAL_16]], %[[VAL_17]] : index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_19:.*]]: index, %[[VAL_20:.*]]: index):
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
+// CHECK: scf.if %[[VAL_22]] {
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<?xf64>
+// CHECK: %[[VAL_24:.*]] = arith.addf %[[VAL_23]], %[[VAL_6]] : f64
+// CHECK: %[[VAL_25:.*]] = math.exp %[[VAL_24]] : f64
+// CHECK: %[[VAL_26:.*]] = arith.maximumf %[[VAL_25]], %[[VAL_7]] : f64
+// CHECK: memref.store %[[VAL_26]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<100xf64>
+// CHECK: } else {
+// CHECK: scf.if %[[VAL_1]] {
+// CHECK: memref.store %[[VAL_7]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<100xf64>
+// CHECK: } else {
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
+// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_19]], %[[VAL_2]] : index
+// CHECK: %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], %[[VAL_19]] : index
+// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_20]], %[[VAL_2]] : index
+// CHECK: scf.yield %[[VAL_29]], %[[VAL_30]] : index, index
+// CHECK: }
+// CHECK: scf.for %[[VAL_31:.*]] = %[[VAL_32:.*]]#1 to %[[VAL_5]] step %[[VAL_2]] {
+// CHECK: memref.store %[[VAL_7]], %[[VAL_12]]{{\[}}%[[VAL_31]]] : memref<100xf64>
+// CHECK: }
+// CHECK: %[[VAL_33:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<100xf64>
+// CHECK: return %[[VAL_33]] : tensor<100xf64>
+// CHECK: }
func.func @sparse_fusion(%argA: tensor<100xf64, #SV>) -> tensor<100xf64> {
%c1 = arith.constant 1.0 : f64
%c100 = arith.constant 100.0 : f64
- //
- // Densifying op.
- // Should not be fused with subsequent dense ops.
- //
%t0 = tensor.empty() : tensor<100xf64>
%l0 = linalg.generic #trait
ins(%argA: tensor<100xf64, #SV>) outs(%t0: tensor<100xf64>) {
@@ -34,12 +71,6 @@ func.func @sparse_fusion(%argA: tensor<100xf64, #SV>) -> tensor<100xf64> {
%b0 = arith.addf %in0, %c1 : f64
linalg.yield %b0 : f64
} -> tensor<100xf64>
-
-
- //
- // Two following dense ops.
- // Should be fused, but not with above.
- //
%t1 = tensor.empty() : tensor<100xf64>
%l1 = linalg.generic #trait
ins(%l0: tensor<100xf64>) outs(%t1: tensor<100xf64>) {
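
For reference, a hypothetical standalone C++ sketch of what the sparsified,
fused kernel verified by the CHECK lines above computes (the function name
and container types are illustrative, not produced by the patch): for a 1-D
compressed input A of size 100, the result is B(i) = max(exp(A(i) + 1), 100),
and since an implicit zero gives max(exp(1), 100) = 100, missing entries
simply receive the constant 100.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Computes B(i) = max(exp(A(i) + 1), 100) for a length-100 compressed
    // vector A given by its coordinates and values (sorted by coordinate).
    std::vector<double> fusedKernel(const std::vector<int64_t> &coords,
                                    const std::vector<double> &values) {
      // Implicit zeros evaluate to max(exp(1), 100) == 100.
      std::vector<double> out(100, 100.0);
      for (size_t k = 0; k < coords.size(); ++k)
        out[coords[k]] = std::max(std::exp(values[k] + 1.0), 100.0);
      return out;
    }

The generated IR arrives at the same result differently: it co-iterates the
stored entries with an scf.while over the positions/coordinates arrays and
then fills the remaining dense tail with an scf.for, which is exactly the
structure the CHECK lines pin down.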