[Mlir-commits] [mlir] 986287e - [mlir][SparseTensor] Fix invalid API usage in patterns (#74690)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 6 19:05:24 PST 2023


Author: Matthias Springer
Date: 2023-12-07T12:05:20+09:00
New Revision: 986287e7f38321165c0c654f3af06e34af7b161f

URL: https://github.com/llvm/llvm-project/commit/986287e7f38321165c0c654f3af06e34af7b161f
DIFF: https://github.com/llvm/llvm-project/commit/986287e7f38321165c0c654f3af06e34af7b161f.diff

LOG: [mlir][SparseTensor] Fix invalid API usage in patterns (#74690)

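Rewrite patterns must return `success` whenever they have modified the IR.
Previously, some patterns modified the IR and then returned `failure`:
- fusing elementwise ops before bailing out on sparse producers;
- creating `ReinterpretMapOp`s before the subclass rewrite could fail;
- setting the `sorted` attribute before bailing out.

This commit fixes sparse tensor tests such as
`SparseTensor/sparse_fusion.mlir`,
`SparseTensor/CPU/sparse_reduce_custom.mlir`, and
`SparseTensor/CPU/sparse_semiring_select.mlir`, which failed when running with
`MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` enabled.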

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f0393e44fc00c..dc5ea28b67cdc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -422,11 +422,6 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
       if (!controlFn(&opOperand))
         continue;
 
-      // Find the producer of the operand.
-      FailureOr<ElementwiseOpFusionResult> fusionResult =
-          fuseElementwiseOps(rewriter, &opOperand);
-      if (failed(fusionResult))
-        return rewriter.notifyMatchFailure(genericOp, "fusion failed");
       Operation *producer = opOperand.get().getDefiningOp();
 
       // Do not fuse a sparse-in/dense-out operation, as the
@@ -435,6 +430,12 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
           !sparse_tensor::hasAnySparseResult(producer))
         return failure();
 
+      // Find the producer of the operand.
+      FailureOr<ElementwiseOpFusionResult> fusionResult =
+          fuseElementwiseOps(rewriter, &opOperand);
+      if (failed(fusionResult))
+        return rewriter.notifyMatchFailure(genericOp, "fusion failed");
+
       // Perform the fusion.
       for (auto [origVal, replacement] : fusionResult->replacements) {
         rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index c94ef8b962877..488079cfe4e32 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -38,16 +38,22 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
   LogicalResult matchAndRewrite(SourceOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
+
     // Demaps non-trivial inputs.
+    bool changed = false;
     SmallVector<Value> deMappedIns(op->getOperands());
-    for (Value &in : deMappedIns)
-      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
+    for (Value &in : deMappedIns) {
+      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
         in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
+        changed = true;
+      }
+    }
 
     // CRTP call.
     OpAdaptor adaptor(deMappedIns, op);
-    return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
-                                                          rewriter);
+    LogicalResult status =
+        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
+    return changed ? success() : status;
   }
 };
 
@@ -452,11 +458,13 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
     }
 
     // Marks the GenericOp to avoid recursive matching.
-    linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
+    rewriter.updateRootInPlace(linalgOp, [&]() {
+      linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
+    });
 
     // Already sorted.
     if (order.isIdentity())
-      return failure();
+      return success();
 
     assert(order.isPermutation());
     // `order` is orignial loop -> sorted loop map


        


More information about the Mlir-commits mailing list