[Mlir-commits] [mlir] [mlir][SparseTensor] Fix invalid API usage in patterns (PR #74690)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 6 18:55:30 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Rewrite patterns must return `success()` whenever they modify the IR. This commit fixes sparse tensor tests that fail when running with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` enabled, such as `SparseTensor/sparse_fusion.mlir`, `SparseTensor/CPU/sparse_reduce_custom.mlir`, and `SparseTensor/CPU/sparse_semiring_select.mlir`.

---
Full diff: https://github.com/llvm/llvm-project/pull/74690.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+6-5) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp (+14-6) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f0393e44fc00c..dc5ea28b67cdc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -422,11 +422,6 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
       if (!controlFn(&opOperand))
         continue;
 
-      // Find the producer of the operand.
-      FailureOr<ElementwiseOpFusionResult> fusionResult =
-          fuseElementwiseOps(rewriter, &opOperand);
-      if (failed(fusionResult))
-        return rewriter.notifyMatchFailure(genericOp, "fusion failed");
       Operation *producer = opOperand.get().getDefiningOp();
 
       // Do not fuse a sparse-in/dense-out operation, as the
@@ -435,6 +430,12 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
           !sparse_tensor::hasAnySparseResult(producer))
         return failure();
 
+      // Find the producer of the operand.
+      FailureOr<ElementwiseOpFusionResult> fusionResult =
+          fuseElementwiseOps(rewriter, &opOperand);
+      if (failed(fusionResult))
+        return rewriter.notifyMatchFailure(genericOp, "fusion failed");
+
       // Perform the fusion.
       for (auto [origVal, replacement] : fusionResult->replacements) {
         rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index c94ef8b962877..488079cfe4e32 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -38,16 +38,22 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
   LogicalResult matchAndRewrite(SourceOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
+
     // Demaps non-trivial inputs.
+    bool changed = false;
     SmallVector<Value> deMappedIns(op->getOperands());
-    for (Value &in : deMappedIns)
-      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
+    for (Value &in : deMappedIns) {
+      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
         in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
+        changed = true;
+      }
+    }
 
     // CRTP call.
     OpAdaptor adaptor(deMappedIns, op);
-    return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
-                                                          rewriter);
+    LogicalResult status =
+        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
+    return changed ? success() : status;
   }
 };
 
@@ -452,11 +458,13 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
     }
 
     // Marks the GenericOp to avoid recursive matching.
-    linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
+    rewriter.updateRootInPlace(linalgOp, [&]() {
+      linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
+    });
 
     // Already sorted.
     if (order.isIdentity())
-      return failure();
+      return success();
 
     assert(order.isPermutation());
     // `order` is orignial loop -> sorted loop map

``````````

</details>


https://github.com/llvm/llvm-project/pull/74690


More information about the Mlir-commits mailing list