[Mlir-commits] [mlir] [mlir][SparseTensor] Fix invalid API usage in patterns (PR #74690)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 6 18:55:30 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Rewrite patterns must return `success` if they modified the IR; conversely, a pattern that returns `failure` must leave the IR unchanged. This commit fixes sparse tensor tests such as `SparseTensor/sparse_fusion.mlir`, `SparseTensor/CPU/sparse_reduce_custom.mlir`, and `SparseTensor/CPU/sparse_semiring_select.mlir`, which failed when MLIR was built with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` enabled.
---
Full diff: https://github.com/llvm/llvm-project/pull/74690.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+6-5)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp (+14-6)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f0393e44fc00c..dc5ea28b67cdc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -422,11 +422,6 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
if (!controlFn(&opOperand))
continue;
- // Find the producer of the operand.
- FailureOr<ElementwiseOpFusionResult> fusionResult =
- fuseElementwiseOps(rewriter, &opOperand);
- if (failed(fusionResult))
- return rewriter.notifyMatchFailure(genericOp, "fusion failed");
Operation *producer = opOperand.get().getDefiningOp();
// Do not fuse a sparse-in/dense-out operation, as the
@@ -435,6 +430,12 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
!sparse_tensor::hasAnySparseResult(producer))
return failure();
+ // Find the producer of the operand.
+ FailureOr<ElementwiseOpFusionResult> fusionResult =
+ fuseElementwiseOps(rewriter, &opOperand);
+ if (failed(fusionResult))
+ return rewriter.notifyMatchFailure(genericOp, "fusion failed");
+
// Perform the fusion.
for (auto [origVal, replacement] : fusionResult->replacements) {
rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index c94ef8b962877..488079cfe4e32 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -38,16 +38,22 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
LogicalResult matchAndRewrite(SourceOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
+
// Demaps non-trivial inputs.
+ bool changed = false;
SmallVector<Value> deMappedIns(op->getOperands());
- for (Value &in : deMappedIns)
- if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
+ for (Value &in : deMappedIns) {
+ if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
+ changed = true;
+ }
+ }
// CRTP call.
OpAdaptor adaptor(deMappedIns, op);
- return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
- rewriter);
+ LogicalResult status =
+ static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
+ return changed ? success() : status;
}
};
@@ -452,11 +458,13 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
}
// Marks the GenericOp to avoid recursive matching.
- linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
+ rewriter.updateRootInPlace(linalgOp, [&]() {
+ linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
+ });
// Already sorted.
if (order.isIdentity())
- return failure();
+ return success();
assert(order.isPermutation());
// `order` is orignial loop -> sorted loop map
``````````
</details>
https://github.com/llvm/llvm-project/pull/74690
More information about the Mlir-commits mailing list