[Mlir-commits] [mlir] [mlir][sparse] make sparse compiler more admissible. (PR #90927)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 2 18:39:55 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Peiming Liu (PeimingLiu)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/90927.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (-6) 
- (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+29-25) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_fusion.mlir (+50-19) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 89fb4944c0ca3c..ad313c2d5ce603 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -432,12 +432,6 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 
       Operation *producer = opOperand.get().getDefiningOp();
 
-      // Do not fuse a sparse-in/dense-out operation, as the
-      // result is too often not sparsifiable anymore.
-      if (sparse_tensor::hasAnySparseOperand(producer) &&
-          !sparse_tensor::hasAnySparseResult(producer))
-        return failure();
-
       // Find the producer of the operand.
       FailureOr<ElementwiseOpFusionResult> fusionResult =
           fuseElementwiseOps(rewriter, &opOperand);
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 9c0aed3c18eff2..308fbd965259db 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1356,50 +1356,54 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   // See buildLattices() for an explanation of rejecting certain
   // division and shift operations.
   if (def->getNumOperands() == 2) {
-    const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
-    const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
-    bool hasSpDep = xDepSp || yDepSp;
+    const auto [x, xSpVals] = buildTensorExp(op, def->getOperand(0));
+    const auto [y, ySpVals] = buildTensorExp(op, def->getOperand(1));
+    // For a conjunctive operation, it yields a "sparse" result if any operand
+    // is sparse. For a disjunctive operation, it yields a "sparse" result if
+    // all operands are sparse.
+    bool conjSpVals = xSpVals || ySpVals;
+    bool disjSpVals = xSpVals && ySpVals;
     if (x.has_value() && y.has_value()) {
       const ExprId e0 = *x;
       const ExprId e1 = *y;
       if (isa<arith::MulFOp>(def))
-        return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kMulF, e0, e1), conjSpVals};
       if (isa<complex::MulOp>(def))
-        return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kMulC, e0, e1), conjSpVals};
       if (isa<arith::MulIOp>(def))
-        return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kMulI, e0, e1), conjSpVals};
       if (isa<arith::DivFOp>(def) && !maybeZero(e1))
-        return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kDivF, e0, e1), conjSpVals};
       if (isa<complex::DivOp>(def) && !maybeZero(e1))
-        return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kDivC, e0, e1), conjSpVals};
       if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
-        return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kDivS, e0, e1), conjSpVals};
       if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
-        return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kDivU, e0, e1), conjSpVals};
       if (isa<arith::AddFOp>(def))
-        return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kAddF, e0, e1), disjSpVals};
       if (isa<complex::AddOp>(def))
-        return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kAddC, e0, e1), disjSpVals};
       if (isa<arith::AddIOp>(def))
-        return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kAddI, e0, e1), disjSpVals};
       if (isa<arith::SubFOp>(def))
-        return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kSubF, e0, e1), disjSpVals};
       if (isa<complex::SubOp>(def))
-        return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kSubC, e0, e1), disjSpVals};
       if (isa<arith::SubIOp>(def))
-        return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kSubI, e0, e1), disjSpVals};
       if (isa<arith::AndIOp>(def))
-        return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kAndI, e0, e1), conjSpVals};
       if (isa<arith::OrIOp>(def))
-        return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kOrI, e0, e1), disjSpVals};
       if (isa<arith::XOrIOp>(def))
-        return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kXorI, e0, e1), disjSpVals};
       if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
-        return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kShrS, e0, e1), conjSpVals};
       if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
-        return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kShrU, e0, e1), conjSpVals};
       if (isa<arith::ShLIOp>(def) && isInvariant(e1))
-        return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kShlI, e0, e1), conjSpVals};
       if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
         if (ci.getPredicate() == arith::CmpIPredicate::eq &&
             ci.getPredicate() == arith::CmpIPredicate::sle &&
@@ -1413,7 +1417,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
 
         auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
                         ci.getPredicateAttr());
-        return {e, hasSpDep};
+        return {e, conjSpVals};
       }
       if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
         if (cf.getPredicate() == arith::CmpFPredicate::OEQ &&
@@ -1431,7 +1435,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
         }
         auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
                         cf.getPredicateAttr());
-        return {e, hasSpDep};
+        return {e, conjSpVals};
       }
       if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
         if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
@@ -1439,7 +1443,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
              isAdmissibleBranch(binop, binop.getLeftRegion())) &&
             (binop.getRightIdentity() ||
              isAdmissibleBranch(binop, binop.getRightRegion())))
-          return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep};
+          return {addExp(TensorExp::Kind::kBinary, e0, e1, def), conjSpVals};
       }
     }
   }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir b/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
index 8780baac199e16..2cc64434a1d8f2 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --linalg-fuse-elementwise-ops | FileCheck %s
+// RUN: mlir-opt %s --linalg-fuse-elementwise-ops --sparse-reinterpret-map --sparsification | FileCheck %s
 
 #SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
 
@@ -11,22 +11,59 @@
   doc = "B(i) = OP A(i)"
 }
 
-// CHECK-LABEL: func @sparse_fusion
-// CHECK:     linalg.generic
-// CHECK:       arith.addf
-// CHECK:     linalg.generic
-// CHECK:       math.exp
-// CHECK:       arith.maximumf
-// CHECK-NOT: linalg.generic
-// CHECK:     return
+
+// CHECK-LABEL:   func.func @sparse_fusion(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<100xf64, #sparse>) -> tensor<100xf64> {
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant true
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 100 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1.000000e+00 : f64
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1.000000e+02 : f64
+// CHECK-DAG:       %[[VAL_8:.*]] = tensor.empty() : tensor<100xf64>
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<100xf64, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<100xf64, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<100xf64, #sparse> to memref<?xf64>
+// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_8]] : memref<100xf64>
+// CHECK:           linalg.fill ins(%[[VAL_4]] : f64) outs(%[[VAL_12]] : memref<100xf64>)
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_13]], %[[VAL_17:.*]] = %[[VAL_3]]) : (index, index) -> (index, index) {
+// CHECK:             %[[VAL_18:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_14]] : index
+// CHECK:             scf.condition(%[[VAL_18]]) %[[VAL_16]], %[[VAL_17]] : index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_19:.*]]: index, %[[VAL_20:.*]]: index):
+// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK:             %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
+// CHECK:             scf.if %[[VAL_22]] {
+// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<?xf64>
+// CHECK:               %[[VAL_24:.*]] = arith.addf %[[VAL_23]], %[[VAL_6]] : f64
+// CHECK:               %[[VAL_25:.*]] = math.exp %[[VAL_24]] : f64
+// CHECK:               %[[VAL_26:.*]] = arith.maximumf %[[VAL_25]], %[[VAL_7]] : f64
+// CHECK:               memref.store %[[VAL_26]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<100xf64>
+// CHECK:             } else {
+// CHECK:               scf.if %[[VAL_1]] {
+// CHECK:                 memref.store %[[VAL_7]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<100xf64>
+// CHECK:               } else {
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
+// CHECK:             %[[VAL_28:.*]] = arith.addi %[[VAL_19]], %[[VAL_2]] : index
+// CHECK:             %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], %[[VAL_19]] : index
+// CHECK:             %[[VAL_30:.*]] = arith.addi %[[VAL_20]], %[[VAL_2]] : index
+// CHECK:             scf.yield %[[VAL_29]], %[[VAL_30]] : index, index
+// CHECK:           }
+// CHECK:           scf.for %[[VAL_31:.*]] = %[[VAL_32:.*]]#1 to %[[VAL_5]] step %[[VAL_2]] {
+// CHECK:             memref.store %[[VAL_7]], %[[VAL_12]]{{\[}}%[[VAL_31]]] : memref<100xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_33:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<100xf64>
+// CHECK:           return %[[VAL_33]] : tensor<100xf64>
+// CHECK:         }
 func.func @sparse_fusion(%argA: tensor<100xf64, #SV>) -> tensor<100xf64> {
   %c1 = arith.constant 1.0 : f64
   %c100 = arith.constant 100.0 : f64
 
-  //
-  // Densifying op.
-  // Should not be fused with subsequent dense ops.
-  //
   %t0 = tensor.empty() : tensor<100xf64>
   %l0 = linalg.generic #trait
       ins(%argA: tensor<100xf64, #SV>) outs(%t0: tensor<100xf64>) {
@@ -34,12 +71,6 @@ func.func @sparse_fusion(%argA: tensor<100xf64, #SV>) -> tensor<100xf64> {
       %b0 = arith.addf %in0, %c1 : f64
       linalg.yield %b0 : f64
   } -> tensor<100xf64>
-
-
-  //
-  // Two following dense ops.
-  // Should be fused, but not with above.
-  //
   %t1 = tensor.empty() : tensor<100xf64>
   %l1 = linalg.generic #trait
       ins(%l0: tensor<100xf64>) outs(%t1: tensor<100xf64>) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/90927


More information about the Mlir-commits mailing list