[Mlir-commits] [mlir] 5c03c05 - [mlir][sparse] enhance element-wise fusion heuristics

Aart Bik llvmlistbot at llvm.org
Thu Jun 15 16:48:47 PDT 2023


Author: Aart Bik
Date: 2023-06-15T16:48:40-07:00
New Revision: 5c03c056e00e06cd722acd645a5ef4d65dd9c168

URL: https://github.com/llvm/llvm-project/commit/5c03c056e00e06cd722acd645a5ef4d65dd9c168
DIFF: https://github.com/llvm/llvm-project/commit/5c03c056e00e06cd722acd645a5ef4d65dd9c168.diff

LOG: [mlir][sparse] enhance element-wise fusion heuristics

We prevent merging a sparse-in/dense-out with dense-in
kernels because the result is usually not sparsifiable.
Dense kernels and sparse kernels are still fused, obviously.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D153077

Added: 
    mlir/test/Dialect/SparseTensor/sparse_fusion.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 2e3753cbae8d4..9cab6b6a027cd 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -140,15 +140,23 @@ RankedTensorType getCOOFromTypeWithOrdering(RankedTensorType src,
 
 RankedTensorType getCOOFromType(RankedTensorType src, bool ordered);
 
-/// Returns true iff MLIR operand has any sparse operand or result.
-inline bool hasAnySparseOperandOrResult(Operation *op) {
-  bool anySparseIn = llvm::any_of(op->getOperands().getTypes(), [](Type t) {
+/// Returns true iff MLIR operand has any sparse operand.
+inline bool hasAnySparseOperand(Operation *op) {
+  return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
     return getSparseTensorEncoding(t) != nullptr;
   });
-  bool anySparseOut = llvm::any_of(op->getResults().getTypes(), [](Type t) {
+}
+
+/// Returns true iff MLIR operand has any sparse result.
+inline bool hasAnySparseResult(Operation *op) {
+  return llvm::any_of(op->getResults().getTypes(), [](Type t) {
     return getSparseTensorEncoding(t) != nullptr;
   });
-  return anySparseIn || anySparseOut;
+}
+
+/// Returns true iff MLIR operand has any sparse operand or result.
+inline bool hasAnySparseOperandOrResult(Operation *op) {
+  return hasAnySparseOperand(op) || hasAnySparseResult(op);
 }
 
 //

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 6a800578f4ac8..528fc477422db 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -422,11 +422,20 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
       if (!controlFn(&opOperand))
         continue;
 
+      // Find the producer of the operand.
       FailureOr<ElementwiseOpFusionResult> fusionResult =
           fuseElementwiseOps(rewriter, &opOperand);
       if (failed(fusionResult))
         return rewriter.notifyMatchFailure(genericOp, "fusion failed");
       Operation *producer = opOperand.get().getDefiningOp();
+
+      // Do not fuse a sparse-in/dense-out operation, as the
+      // result is too often not sparsifiable anymore.
+      if (sparse_tensor::hasAnySparseOperand(producer) &&
+          !sparse_tensor::hasAnySparseResult(producer))
+        return failure();
+
+      // Perform the fusion.
       for (auto [origVal, replacement] : fusionResult->replacements) {
         rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
           // Only replace consumer uses.

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir b/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
new file mode 100644
index 0000000000000..d6f4ca58ac642
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt %s --linalg-fuse-elementwise-ops | FileCheck %s
+
+#SV = #sparse_tensor.encoding<{ lvlTypes = ["compressed"] }>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>, // A
+    affine_map<(i) -> (i)>  // B (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "B(i) = OP A(i)"
+}
+
+// CHECK-LABEL: func @sparse_fusion
+// CHECK:     linalg.generic
+// CHECK:       arith.addf
+// CHECK:     linalg.generic
+// CHECK:       math.exp
+// CHECK:       arith.maxf
+// CHECK-NOT: linalg.generic
+// CHECK:     return
+func.func @sparse_fusion(%argA: tensor<100xf64, #SV>) -> tensor<100xf64> {
+  %c1 = arith.constant 1.0 : f64
+  %c100 = arith.constant 100.0 : f64
+
+  //
+  // Densifying op.
+  // Should not be fused with subsequent dense ops.
+  //
+  %t0 = tensor.empty() : tensor<100xf64>
+  %l0 = linalg.generic #trait
+      ins(%argA: tensor<100xf64, #SV>) outs(%t0: tensor<100xf64>) {
+    ^bb0(%in0: f64, %out0: f64):
+      %b0 = arith.addf %in0, %c1 : f64
+      linalg.yield %b0 : f64
+  } -> tensor<100xf64>
+
+
+  //
+  // Two following dense ops.
+  // Should be fused, but not with above.
+  //
+  %t1 = tensor.empty() : tensor<100xf64>
+  %l1 = linalg.generic #trait
+      ins(%l0: tensor<100xf64>) outs(%t1: tensor<100xf64>) {
+    ^bb0(%in1: f64, %out1: f64):
+      %b1 = math.exp %in1 : f64
+      linalg.yield %b1 : f64
+  } -> tensor<100xf64>
+  %t2 = tensor.empty() : tensor<100xf64>
+  %l2 = linalg.generic #trait
+      ins(%l1: tensor<100xf64>) outs(%t2: tensor<100xf64>) {
+    ^bb0(%in2: f64, %out2: f64):
+      %b2 = arith.maxf %in2, %c100 : f64
+      linalg.yield %b2 : f64
+  } -> tensor<100xf64>
+
+  return %l2 : tensor<100xf64>
+}


        


More information about the Mlir-commits mailing list