[Mlir-commits] [mlir] [MLIR][Shard] Fix NormalizeSharding and FoldDuplicateShardOp direct mutations (PR #188981)

Mehdi Amini llvmlistbot at llvm.org
Fri Mar 27 09:10:27 PDT 2026


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/188981

>From c2bafe5b669c1d462a2260f2cbf5815bf6d9b886 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 26 Mar 2026 15:57:21 -0700
Subject: [PATCH] [MLIR][Shard] Fix NormalizeSharding and FoldDuplicateShardOp
 direct mutations

NormalizeSharding::matchAndRewrite was directly calling attribute setters
and MutableOperandRange::assign() without going through the PatternRewriter,
bypassing the rewriter's change tracking. With
MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS enabled, this triggered the
"operation finger print changed" error after the pattern returned success.

Similarly, FoldDuplicateShardOp::matchAndRewrite was directly calling
op.getSrcMutable().assign() in the else-branch without notifying the
rewriter, causing the same fingerprint-change error.

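Fix: wrap both sets of direct mutations in PatternRewriter::modifyOpInPlace()
so the rewriter is properly notified of in-place changes. In
NormalizeSharding, the structured bindings returned by decomposeMixedValues()
are replaced with ordinary local variables, because C++17 lambdas cannot
capture structured bindings and the new modifyOpInPlace() callback uses them.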

Assisted-by: Claude Code
This fixes a failure that occurs when MLIR is configured with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON.
---
 mlir/lib/Dialect/Shard/IR/ShardOps.cpp | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
index a173da3db1d18..ff790a0bf961d 100644
--- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
+++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
@@ -631,8 +631,12 @@ class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
     bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) ||
                     succeeded(foldDynamicIndexList(mixedOffs, true));
 
-    auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos);
-    auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs);
+    auto decomposedHalos = decomposeMixedValues(mixedHalos);
+    auto staticHalos = decomposedHalos.first;
+    auto dynamicHalos = decomposedHalos.second;
+    auto decomposedOffs = decomposeMixedValues(mixedOffs);
+    auto staticOffs = decomposedOffs.first;
+    auto dynamicOffs = decomposedOffs.second;
 
     if (dynamicHalos.empty() && !staticHalos.empty()) {
       if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
@@ -666,11 +670,12 @@ class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
       return failure();
     }
 
-    op.setStaticHaloSizes(staticHalos);
-    op.getDynamicHaloSizesMutable().assign(dynamicHalos);
-    op.setStaticShardedDimsOffsets(staticOffs);
-    op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
-
+    b.modifyOpInPlace(op, [&]() {
+      op.setStaticHaloSizes(staticHalos);
+      op.getDynamicHaloSizesMutable().assign(dynamicHalos);
+      op.setStaticShardedDimsOffsets(staticOffs);
+      op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
+    });
     return success();
   }
 };
@@ -877,7 +882,8 @@ class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
           b.eraseOp(op.getOperation());
         } else {
           // use the other sharding as input for op
-          op.getSrcMutable().assign(otherOp.getResult());
+          b.modifyOpInPlace(
+              op, [&]() { op.getSrcMutable().assign(otherOp.getResult()); });
         }
         return success();
       }



More information about the Mlir-commits mailing list