[Mlir-commits] [mlir] [MLIR][MemRef] Fix DimOfReifyRankedShapedTypeOpInterface IR-change on failure (PR #188973)

Mehdi Amini llvmlistbot at llvm.org
Fri Mar 27 04:59:37 PDT 2026


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/188973

From 37fe4fdcf5bdacb106d74c50c33501120b34fb70 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 26 Mar 2026 15:57:23 -0700
Subject: [PATCH] [MLIR][MemRef] Fix DimOfReifyRankedShapedTypeOpInterface
 IR-change on failure

DimOfReifyRankedShapedTypeOpInterface::matchAndRewrite called
reifyDimOfResult via the PatternRewriter. Some implementations delegate
to the coarse-grained reifyResultShapes, which creates ops for ALL
dimensions (e.g. a tensor.dim on an operand) before discovering that the
requested dimension is not reifiable (signalled by an empty OpFoldResult).
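
To make the failure mode concrete, here is a minimal standalone C++ sketch (not MLIR code) that models a block as a std::list of op names. Block, reifyAllDims, reifyOneDim and runFailedReification are hypothetical stand-ins for the block, reifyResultShapes, reifyDimOfResult and the regression-test scenario. The only point it makes is that the coarse-grained call mutates the block even when the requested dimension comes back empty.

```cpp
#include <cassert>
#include <cstddef>
#include <iterator>
#include <list>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for an MLIR block: op names in program order.
using Block = std::list<std::string>;

// Plays the role of reifyResultShapes: materializes one op per reifiable
// dimension before `insertPt` and reports unreifiable ones as std::nullopt.
std::vector<std::optional<std::string>>
reifyAllDims(Block &block, Block::iterator insertPt,
             const std::vector<bool> &reifiable) {
  std::vector<std::optional<std::string>> result;
  for (std::size_t d = 0; d < reifiable.size(); ++d) {
    if (!reifiable[d]) {
      result.push_back(std::nullopt);
      continue;
    }
    std::string op = "tensor.dim %arg0, " + std::to_string(d);
    block.insert(insertPt, op); // Side effect: the "IR" changes here.
    result.push_back(op);
  }
  return result;
}

// Plays the role of reifyDimOfResult delegating to the coarse-grained query.
std::optional<std::string> reifyOneDim(Block &block, Block::iterator insertPt,
                                       const std::vector<bool> &reifiable,
                                       std::size_t dim) {
  return reifyAllDims(block, insertPt, reifiable)[dim];
}

// Mirrors the regression test: dim 0 is reifiable, dim 1 is not, and the
// pattern asks for dim 1. Returns the block as left after the attempt.
Block runFailedReification(std::optional<std::string> &answer) {
  Block block = {"%0 = test.unreifiable_result_shapes",
                 "%d1 = tensor.dim %0, 1", "return %d1"};
  Block::iterator dimOp = std::next(block.begin());
  answer = reifyOneDim(block, dimOp, {true, false}, /*dim=*/1);
  return block;
}
```

After runFailedReification the answer is empty, yet the block has grown by one stray op placed just before the dim op: the state the pattern was leaving behind when it returned failure().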

The pattern then returned failure() once it saw the empty OpFoldResult,
but the newly created ops were already in the IR. Under
MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS this triggered "pattern returned
failure but IR did change".
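
The invariant behind that message can be pictured with the hedged sketch below; it is not the driver's actual code. The real greedy driver compares operation fingerprints, whereas this model keeps an exact snapshot of a hypothetical std::list-based Block, and Outcome and applyChecked are invented names.

```cpp
#include <cassert>
#include <functional>
#include <list>
#include <string>

// Hypothetical stand-in for the IR a pattern may touch: op names in order.
using Block = std::list<std::string>;

// What the (sketched) expensive check concludes about one pattern application.
enum class Outcome { Applied, Failed, FailedButChangedIR };

// Runs `pattern` (true = success). On failure, compares the block against a
// snapshot taken beforehand; any difference is the API violation that MLIR
// reports as "pattern returned failure but IR did change".
Outcome applyChecked(Block &block,
                     const std::function<bool(Block &)> &pattern) {
  const Block snapshot = block; // MLIR uses a cheaper fingerprint instead.
  if (pattern(block))
    return Outcome::Applied;
  return block == snapshot ? Outcome::Failed : Outcome::FailedButChangedIR;
}
```

A pattern that inserts an op and then reports failure is classified as FailedButChangedIR, while one that fails without touching the block is a plain Failed.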

Fix: record the op immediately before the matched dim op. Because the
rewriter's insertion point is right before the matched op, any ops created
there during the reification attempt land between the recorded op and the
dim op. If reification fails or returns an empty (unreifiable)
OpFoldResult, erase those newly created ops, newest first, before
returning failure, restoring the IR to its original state.
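
A minimal standalone model of this fix, assuming a std::list of op names in place of an MLIR block and a hypothetical Reifier callback in place of reifyDimOfResult. rewriteDim records the predecessor, lets the reifier insert ops before the dim op, and on failure walks backward erasing them. It recomputes the candidate on each step rather than stepping from an iterator that was just erased.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <optional>
#include <string>

// Hypothetical stand-in for an MLIR block: op names in program order.
using Block = std::list<std::string>;

// Hypothetical reifier: may insert ops before `insertPt` and then report the
// requested dimension as unreifiable by returning std::nullopt.
using Reifier = std::optional<std::string> (*)(Block &, Block::iterator);

// Sketch of the fixed matchAndRewrite: returns true on success. On failure it
// erases every op the reifier inserted immediately before `dimOp`, newest
// first, so the block is left exactly as it was.
bool rewriteDim(Block &block, Block::iterator dimOp, Reifier reify) {
  // Record the predecessor before reifying; std::list iterators to it stay
  // valid while new elements are inserted after it.
  const bool hadPrev = dimOp != block.begin();
  const Block::iterator prevOp = hadPrev ? std::prev(dimOp) : block.end();

  std::optional<std::string> replacement = reify(block, dimOp);
  if (!replacement) {
    // Stop at the recorded predecessor, or at the front of the block if the
    // dim op had none. The candidate is re-derived from dimOp every time.
    while (dimOp != block.begin()) {
      Block::iterator candidate = std::prev(dimOp);
      if (hadPrev && candidate == prevOp)
        break;
      block.erase(candidate);
    }
    return false;
  }
  *dimOp = *replacement; // Stand-in for rewriter.replaceOp.
  return true;
}
```

Stopping at the recorded predecessor is what limits the cleanup to ops created during this attempt; ops that already preceded the dim op are never visited.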

This fixes a failure seen with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON.

Assisted-by: Claude Code
---
 .../ResolveShapedTypeResultDims.cpp           | 30 ++++++++++++++++---
 .../resolve-shaped-type-result-dims.mlir      | 21 +++++++++++++
 2 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index c498c8a60bf6e..cb0c03de266e1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -90,13 +90,35 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
     if (!dimIndex)
       return failure();
 
+    // Save the op immediately before dimOp so we can identify and erase any
+    // ops inserted during the reification attempt if it fails. The
+    // pattern-rewrite invariant requires the IR to be unchanged on failure.
+    Operation *opBeforeReify = dimOp->getPrevNode();
+
+    // Erase any ops inserted between opBeforeReify and dimOp in reverse order
+    // to respect use-def chains within that range. Fetch each predecessor
+    // before erasing the current op: erasing invalidates iterators to it, so
+    // stepping backward from an erased op would read freed memory.
+    auto eraseInsertedOps = [&]() {
+      Operation *op = dimOp->getPrevNode();
+      while (op != opBeforeReify) {
+        Operation *prev = op->getPrevNode();
+        rewriter.eraseOp(op);
+        op = prev;
+      }
+    };
+
     FailureOr<OpFoldResult> replacement = reifyDimOfResult(
         rewriter, dimValue.getOwner(), dimValue.getResultNumber(), *dimIndex);
-    if (failed(replacement))
-      return failure();
-    // Check if the OpFoldResult is empty (unreifiable dimension).
-    if (!replacement.value())
+    // A failed result or an empty OpFoldResult signals that this specific
+    // dimension cannot be reified. Some implementations materialize all
+    // dimensions at once (e.g. via reifyResultShapes) and may create ops for
+    // other dimensions before discovering that this one is not reifiable.
+    // Erase those stray ops before returning failure.
+    if (failed(replacement) || !replacement.value()) {
+      eraseInsertedOps();
       return failure();
+    }
     Value replacementVal = getValueOrCreateConstantIndexOp(
         rewriter, dimOp.getLoc(), replacement.value());
     rewriter.replaceOp(dimOp, replacementVal);
diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
index 624e0990a4bb3..f41312839c094 100644
--- a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
@@ -178,3 +178,24 @@ func.func @test_unreifiable_dim_of_result_shape(%arg0 : tensor<?x?xf32>)
 //   CHECK-DAG:   %[[OP:.+]] = "test.unreifiable_dim_of_result_shape"(%[[ARG0]])
 //       CHECK:   %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]]
 //       CHECK:   return %[[D0]], %[[D1]]
+
+// -----
+
+// Regression test: verify that when reifyResultShapes creates ops for dim 0
+// but signals dim 1 is not reifiable (empty OpFoldResult), those stray ops are
+// erased before failure is returned. Without the fix, the stray tensor.dim op
+// on %arg0 would remain in the IR (caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS).
+func.func @test_unreifiable_result_shapes_no_stray_ops(%arg0 : tensor<?x?xf32>)
+    -> index {
+  %c1 = arith.constant 1 : index
+  %0 = "test.unreifiable_result_shapes"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+  %d1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+  return %d1 : index
+}
+// CHECK-LABEL: func @test_unreifiable_result_shapes_no_stray_ops(
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xf32>)
+//       CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//       CHECK:   %[[OP:.+]] = "test.unreifiable_result_shapes"(%[[ARG0]])
// CHECK-NOT:     tensor.dim %[[ARG0]]
+//       CHECK:   %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]]
+//       CHECK:   return %[[D1]]


