[Mlir-commits] [mlir] [MLIR][Shard] Fix NormalizeSharding and FoldDuplicateShardOp direct mutations (PR #188981)
Mehdi Amini
llvmlistbot at llvm.org
Fri Mar 27 09:10:27 PDT 2026
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/188981
>From c2bafe5b669c1d462a2260f2cbf5815bf6d9b886 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 26 Mar 2026 15:57:21 -0700
Subject: [PATCH] [MLIR][Shard] Fix NormalizeSharding and FoldDuplicateShardOp
direct mutations
NormalizeSharding::matchAndRewrite was directly calling attribute setters
and MutableOperandRange::assign() without going through the PatternRewriter,
bypassing the rewriter's change tracking. With
MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS enabled, this triggered the
"operation finger print changed" error after the pattern returned success.
Similarly, FoldDuplicateShardOp::matchAndRewrite was directly calling
op.getSrcMutable().assign() in the else-branch without notifying the
rewriter, causing the same fingerprint-change error.
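Fix: wrap the direct mutations in both patterns in rewriter.modifyOpInPlace()
so the rewriter is properly notified of in-place changes. In
NormalizeSharding, the structured bindings for the results of
decomposeMixedValues() are also replaced with plain local variables, because
C++17 does not allow structured bindings to be captured by the
modifyOpInPlace() lambda.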
Assisted-by: Claude Code
This fixes a failure that occurs when building with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON.
---
mlir/lib/Dialect/Shard/IR/ShardOps.cpp | 22 ++++++++++++++--------
1 file changed, 14 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
index a173da3db1d18..ff790a0bf961d 100644
--- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
+++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
@@ -631,8 +631,12 @@ class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) ||
succeeded(foldDynamicIndexList(mixedOffs, true));
- auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos);
- auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs);
+ auto decomposedHalos = decomposeMixedValues(mixedHalos);
+ auto staticHalos = decomposedHalos.first;
+ auto dynamicHalos = decomposedHalos.second;
+ auto decomposedOffs = decomposeMixedValues(mixedOffs);
+ auto staticOffs = decomposedOffs.first;
+ auto dynamicOffs = decomposedOffs.second;
if (dynamicHalos.empty() && !staticHalos.empty()) {
if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
@@ -666,11 +670,12 @@ class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
return failure();
}
- op.setStaticHaloSizes(staticHalos);
- op.getDynamicHaloSizesMutable().assign(dynamicHalos);
- op.setStaticShardedDimsOffsets(staticOffs);
- op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
-
+ b.modifyOpInPlace(op, [&]() {
+ op.setStaticHaloSizes(staticHalos);
+ op.getDynamicHaloSizesMutable().assign(dynamicHalos);
+ op.setStaticShardedDimsOffsets(staticOffs);
+ op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
+ });
return success();
}
};
@@ -877,7 +882,8 @@ class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
b.eraseOp(op.getOperation());
} else {
// use the other sharding as input for op
- op.getSrcMutable().assign(otherOp.getResult());
+ b.modifyOpInPlace(
+ op, [&]() { op.getSrcMutable().assign(otherOp.getResult()); });
}
return success();
}
More information about the Mlir-commits
mailing list