[Mlir-commits] [mlir] [mlir][SparseTensor] Fix invalid API usage in patterns (PR #74690)
Matthias Springer
llvmlistbot at llvm.org
Wed Dec 6 18:54:59 PST 2023
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/74690
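Rewrite patterns must return `success` whenever they modify the IR, even if they then bail out. Several patterns violated this rule. `FuseElementwiseOps` performed the fusion before checking whether the producer is a sparse-in/dense-out operation, and could then return `failure` on an already-modified IR. The sparse tensor reinterpret-map patterns also returned `failure` after modifying the IR: `DemapInsRewriter` after creating `ReinterpretMapOp`s, and `GenericOpScheduler` after setting the "sorted" attribute. This commit fixes those patterns, which makes sparse tensor tests such as `SparseTensor/sparse_fusion.mlir`, `SparseTensor/CPU/sparse_reduce_custom.mlir`, and `SparseTensor/CPU/sparse_semiring_select.mlir` pass when MLIR is built with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` enabled.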
From 06d57b0aff9844182797118db456b89d6045b266 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 7 Dec 2023 11:53:34 +0900
Subject: [PATCH] [mlir][SparseTensor] Fix invalid API usage in patterns
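Rewrite patterns must return `success` whenever they modify the IR, even if they then bail out. Several patterns violated this rule. `FuseElementwiseOps` performed the fusion before checking whether the producer is a sparse-in/dense-out operation, and could then return `failure` on an already-modified IR. The sparse tensor reinterpret-map patterns also returned `failure` after modifying the IR: `DemapInsRewriter` after creating `ReinterpretMapOp`s, and `GenericOpScheduler` after setting the "sorted" attribute. This commit fixes those patterns, which makes sparse tensor tests such as `SparseTensor/sparse_fusion.mlir`, `SparseTensor/CPU/sparse_reduce_custom.mlir`, and `SparseTensor/CPU/sparse_semiring_select.mlir` pass when MLIR is built with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` enabled.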
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 11 +++++-----
.../Transforms/SparseReinterpretMap.cpp | 20 +++++++++++++------
2 files changed, 20 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f0393e44fc00c..dc5ea28b67cdc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -422,11 +422,6 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
if (!controlFn(&opOperand))
continue;
- // Find the producer of the operand.
- FailureOr<ElementwiseOpFusionResult> fusionResult =
- fuseElementwiseOps(rewriter, &opOperand);
- if (failed(fusionResult))
- return rewriter.notifyMatchFailure(genericOp, "fusion failed");
Operation *producer = opOperand.get().getDefiningOp();
// Do not fuse a sparse-in/dense-out operation, as the
@@ -435,6 +430,12 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
!sparse_tensor::hasAnySparseResult(producer))
return failure();
+ // Find the producer of the operand.
+ FailureOr<ElementwiseOpFusionResult> fusionResult =
+ fuseElementwiseOps(rewriter, &opOperand);
+ if (failed(fusionResult))
+ return rewriter.notifyMatchFailure(genericOp, "fusion failed");
+
// Perform the fusion.
for (auto [origVal, replacement] : fusionResult->replacements) {
rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index c94ef8b962877..488079cfe4e32 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -38,16 +38,22 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
LogicalResult matchAndRewrite(SourceOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
+
// Demaps non-trivial inputs.
+ bool changed = false;
SmallVector<Value> deMappedIns(op->getOperands());
- for (Value &in : deMappedIns)
- if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
+ for (Value &in : deMappedIns) {
+ if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
+ changed = true;
+ }
+ }
// CRTP call.
OpAdaptor adaptor(deMappedIns, op);
- return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
- rewriter);
+ LogicalResult status =
+ static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
+ return changed ? success() : status;
}
};
@@ -452,11 +458,13 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
}
// Marks the GenericOp to avoid recursive matching.
- linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
+ rewriter.updateRootInPlace(linalgOp, [&]() {
+ linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
+ });
// Already sorted.
if (order.isIdentity())
- return failure();
+ return success();
assert(order.isPermutation());
// `order` is orignial loop -> sorted loop map
More information about the Mlir-commits
mailing list