[Mlir-commits] [mlir] [mlir][SparseTensor] Fix invalid API usage in patterns (PR #74690)

Matthias Springer llvmlistbot at llvm.org
Wed Dec 6 18:54:59 PST 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/74690

Rewrite patterns must return `success` if the IR was modified. This commit fixes sparse tensor tests, such as `SparseTensor/sparse_fusion.mlir`, `SparseTensor/CPU/sparse_reduce_custom.mlir`, and `SparseTensor/CPU/sparse_semiring_select.mlir`, that fail when MLIR is built with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` enabled.

>From 06d57b0aff9844182797118db456b89d6045b266 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 7 Dec 2023 11:53:34 +0900
Subject: [PATCH] [mlir][SparseTensor] Fix invalid API usage in patterns

Rewrite patterns must return `success` if the IR was modified. This commit fixes sparse tensor tests, such as `SparseTensor/sparse_fusion.mlir`, `SparseTensor/CPU/sparse_reduce_custom.mlir`, and `SparseTensor/CPU/sparse_semiring_select.mlir`, that fail when MLIR is built with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` enabled.
---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 11 +++++-----
 .../Transforms/SparseReinterpretMap.cpp       | 20 +++++++++++++------
 2 files changed, 20 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f0393e44fc00c..dc5ea28b67cdc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -422,11 +422,6 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
       if (!controlFn(&opOperand))
         continue;
 
-      // Find the producer of the operand.
-      FailureOr<ElementwiseOpFusionResult> fusionResult =
-          fuseElementwiseOps(rewriter, &opOperand);
-      if (failed(fusionResult))
-        return rewriter.notifyMatchFailure(genericOp, "fusion failed");
       Operation *producer = opOperand.get().getDefiningOp();
 
       // Do not fuse a sparse-in/dense-out operation, as the
@@ -435,6 +430,12 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
           !sparse_tensor::hasAnySparseResult(producer))
         return failure();
 
+      // Find the producer of the operand.
+      FailureOr<ElementwiseOpFusionResult> fusionResult =
+          fuseElementwiseOps(rewriter, &opOperand);
+      if (failed(fusionResult))
+        return rewriter.notifyMatchFailure(genericOp, "fusion failed");
+
       // Perform the fusion.
       for (auto [origVal, replacement] : fusionResult->replacements) {
         rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index c94ef8b962877..488079cfe4e32 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -38,16 +38,22 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
   LogicalResult matchAndRewrite(SourceOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
+
     // Demaps non-trivial inputs.
+    bool changed = false;
     SmallVector<Value> deMappedIns(op->getOperands());
-    for (Value &in : deMappedIns)
-      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
+    for (Value &in : deMappedIns) {
+      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
         in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
+        changed = true;
+      }
+    }
 
     // CRTP call.
     OpAdaptor adaptor(deMappedIns, op);
-    return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
-                                                          rewriter);
+    LogicalResult status =
+        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
+    return changed ? success() : status;
   }
 };
 
@@ -452,11 +458,13 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
     }
 
     // Marks the GenericOp to avoid recursive matching.
-    linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
+    rewriter.updateRootInPlace(linalgOp, [&]() {
+      linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
+    });
 
     // Already sorted.
     if (order.isIdentity())
-      return failure();
+      return success();
 
     assert(order.isPermutation());
     // `order` is orignial loop -> sorted loop map



More information about the Mlir-commits mailing list