[Mlir-commits] [mlir] [MLIR][MemRef] Fix DimOfReifyRankedShapedTypeOpInterface IR-change on failure (PR #188973)
llvmlistbot at llvm.org
Fri Mar 27 04:47:30 PDT 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir-memref
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
DimOfReifyRankedShapedTypeOpInterface::matchAndRewrite called reifyDimOfResult with the PatternRewriter. Some implementations delegate to the coarse-grained reifyResultShapes, which creates ops (e.g. a tensor.dim) for ALL dimensions of the result before the caller discovers that the requested dimension is not reifiable (signalled by an empty OpFoldResult).
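To make the failure mode concrete, here is a minimal C++ sketch under stated assumptions: a block is modelled as a `std::list<std::string>` of op names, and `reifyAllDims` / `reifyDimCoarse` are hypothetical stand-ins for `reifyResultShapes` and a per-dimension query that delegates to it. None of these are MLIR APIs; `std::nullopt` plays the role of an empty `OpFoldResult`.

```cpp
#include <cassert>
#include <list>
#include <optional>
#include <string>
#include <vector>

// Toy block: an ordered list of op names (hypothetical, not mlir::Block).
using Block = std::list<std::string>;

// Hypothetical coarse-grained reifier in the spirit of reifyResultShapes: it
// materializes an op for every reifiable dimension at `insertPt` and returns
// one entry per dimension. std::nullopt stands in for an empty OpFoldResult.
std::vector<std::optional<std::string>>
reifyAllDims(Block &block, Block::iterator insertPt, int numDims,
             int unreifiableDim) {
  std::vector<std::optional<std::string>> shapes;
  for (int d = 0; d < numDims; ++d) {
    if (d == unreifiableDim) {
      shapes.push_back(std::nullopt);
      continue;
    }
    std::string op = "tensor.dim #" + std::to_string(d);
    block.insert(insertPt, op); // the IR changes here, before any failure is known
    shapes.push_back(op);
  }
  return shapes;
}

// Hypothetical per-dimension query that delegates to the coarse reifier and
// only afterwards inspects the requested dimension.
std::optional<std::string> reifyDimCoarse(Block &block,
                                          Block::iterator insertPt,
                                          int numDims, int unreifiableDim,
                                          int requested) {
  assert(requested >= 0 && requested < numDims);
  std::vector<std::optional<std::string>> shapes =
      reifyAllDims(block, insertPt, numDims, unreifiableDim);
  return shapes[requested];
}
```

Asking for dimension 1 of a two-dimensional result yields `std::nullopt`, yet the block has already gained a `tensor.dim #0` entry in front of the dim op, which mirrors the stray op the pattern used to leave behind.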
The pattern then returned failure() on seeing the empty OpFoldResult, but the newly created ops were already in the IR. With MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS enabled, this triggered the "pattern returned failure but IR did change" error.
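The invariant behind that diagnostic can be sketched as: fingerprint the IR before running a pattern and, if the pattern reports failure, check that the fingerprint is unchanged. This is a simplified model that assumes the same kind of toy block of op-name strings and a hypothetical hash-based `fingerprint` helper; it does not reproduce the fingerprinting MLIR actually performs, and a hash comparison could in principle miss a change on collision.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <list>
#include <string>

// Toy block of op names (hypothetical stand-in for real IR).
using Block = std::list<std::string>;

// Order-sensitive hash of the toy IR; a hypothetical stand-in for a structural
// fingerprint. A collision could hide a change, so this is illustrative only.
std::size_t fingerprint(const Block &block) {
  std::size_t h = 0;
  for (const std::string &op : block)
    h = h * 31 + std::hash<std::string>{}(op);
  return h;
}

// A pattern returns true on success and false on failure, like LogicalResult.
using Pattern = std::function<bool(Block &)>;

enum class CheckResult { Success, CleanFailure, FailureButIRChanged };

// Runs `pattern` and flags the "pattern returned failure but IR did change"
// case by comparing fingerprints taken before and after the attempt.
CheckResult runChecked(const Pattern &pattern, Block &block) {
  assert(pattern && "pattern must be callable");
  std::size_t before = fingerprint(block);
  if (pattern(block))
    return CheckResult::Success;
  return fingerprint(block) == before ? CheckResult::CleanFailure
                                      : CheckResult::FailureButIRChanged;
}
```

A pattern that inserts an op and then returns `false` is reported as `FailureButIRChanged`; one that fails without touching the block is a `CleanFailure`.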
Fix: record the op immediately preceding the matched dim op so that ops inserted during the reification attempt can be identified. If reification fails or returns an empty (unreifiable) OpFoldResult, erase those newly created ops before returning failure, restoring the IR to its original state.
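The fix can be modelled with the same kind of toy block. In this sketch all names are hypothetical, not MLIR APIs: `reify` inserts ops immediately before the dim op (the range the fix inspects) and returns `std::nullopt` for both outright failure and an unreifiable dimension. `rewriteDim` records the dim op's predecessor up front and, on failure, repeatedly erases whatever now sits directly before the dim op until that predecessor is reached again, so it only ever steps from iterators whose elements are still alive.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <optional>
#include <string>

// Toy block of op names. std::list keeps iterators to untouched elements valid
// across insert/erase, loosely like MLIR's intrusive op list.
using Block = std::list<std::string>;
using OpIt = Block::iterator;

// Hypothetical reifier: inserts one op per reifiable dimension directly before
// `dimOp` and returns the op for `requested`, or std::nullopt if that
// dimension cannot be reified (stand-in for failure or an empty OpFoldResult).
std::optional<std::string> reify(Block &block, OpIt dimOp, int numDims,
                                 int unreifiableDim, int requested) {
  std::optional<std::string> result;
  for (int d = 0; d < numDims; ++d) {
    if (d == unreifiableDim)
      continue;
    std::string op = "dim#" + std::to_string(d);
    block.insert(dimOp, op); // lands between prevOp and dimOp
    if (d == requested)
      result = op;
  }
  return result;
}

// Returns true if the dim op was replaced; on failure the block is restored.
bool rewriteDim(Block &block, OpIt dimOp, int numDims, int unreifiableDim,
                int requested) {
  assert(dimOp != block.end() && "dimOp must point at an element");
  // Analogue of dimOp->getPrevNode(): empty when dimOp is the first op.
  std::optional<OpIt> prevOp;
  if (dimOp != block.begin())
    prevOp = std::prev(dimOp);

  // Erase everything between prevOp and dimOp (both exclusive) in reverse
  // order. Neither prevOp nor dimOp is ever erased, so both iterators stay
  // valid and no iterator to an erased element is reused.
  auto eraseInsertedOps = [&]() {
    while (dimOp != block.begin() && (!prevOp || std::prev(dimOp) != *prevOp))
      block.erase(std::prev(dimOp));
  };

  std::optional<std::string> replacement =
      reify(block, dimOp, numDims, unreifiableDim, requested);
  if (!replacement) {
    eraseInsertedOps();
    return false;
  }
  *dimOp = "replaced by " + *replacement; // stand-in for rewriter.replaceOp
  return true;
}
```

With three dimensions and dimension 1 unreifiable, `reify` leaves `dim#0` and `dim#2` in front of the dim op; `rewriteDim` erases `dim#2`, then `dim#0`, and returns `false` with the block back in its original state.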
Assisted-by: Claude Code
---
Full diff: https://github.com/llvm/llvm-project/pull/188973.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp (+27-3)
- (modified) mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir (+21)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index c498c8a60bf6e..2abf9fb1a58aa 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -90,13 +90,37 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
     if (!dimIndex)
       return failure();
 
+    // Save the insertion position so we can identify and erase any ops created
+    // during the reification attempt if reification fails; the pattern-rewrite
+    // invariant requires the IR to be unchanged on failure.
+    Operation *prevOp = dimOp->getPrevNode();
+
+    // Erase any ops inserted between prevOp and dimOp (exclusive) in reverse
+    // order to respect use-def chains within that range.
+    auto eraseInsertedOps = [&]() {
+      Block::iterator begin = prevOp ? std::next(prevOp->getIterator())
+                                     : dimOp->getBlock()->begin();
+      Block::iterator it = dimOp->getIterator();
+      while (it != begin) {
+        --it;
+        rewriter.eraseOp(&*it);
+      }
+    };
+
     FailureOr<OpFoldResult> replacement = reifyDimOfResult(
         rewriter, dimValue.getOwner(), dimValue.getResultNumber(), *dimIndex);
-    if (failed(replacement))
+    if (failed(replacement)) {
+      eraseInsertedOps();
       return failure();
-    // Check if the OpFoldResult is empty (unreifiable dimension).
-    if (!replacement.value())
+    }
+    // An empty OpFoldResult signals that this specific dimension cannot be
+    // reified. Some implementations materialise all dimensions at once (e.g.
+    // via the coarse-grained reifyResultShapes) and may create ops for other
+    // dimensions before discovering that this dimension is not reifiable.
+    if (!replacement.value()) {
+      eraseInsertedOps();
       return failure();
+    }
     Value replacementVal = getValueOrCreateConstantIndexOp(
         rewriter, dimOp.getLoc(), replacement.value());
     rewriter.replaceOp(dimOp, replacementVal);
diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
index 624e0990a4bb3..b268b81d1b439 100644
--- a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
@@ -178,3 +178,24 @@ func.func @test_unreifiable_dim_of_result_shape(%arg0 : tensor<?x?xf32>)
 // CHECK-DAG: %[[OP:.+]] = "test.unreifiable_dim_of_result_shape"(%[[ARG0]])
 // CHECK: %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]]
 // CHECK: return %[[D0]], %[[D1]]
+
+// -----
+
+// Regression test: verify that when reifyResultShapes creates ops for dim 0
+// but signals dim 1 is not reifiable (empty OpFoldResult), those stray ops are
+// erased before failure is returned. Without the fix, the stray tensor.dim op
+// on %arg0 would remain in the IR (caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS).
+func.func @test_unreifiable_result_shapes_no_stray_ops(%arg0 : tensor<?x?xf32>)
+    -> index {
+  %c1 = arith.constant 1 : index
+  %0 = "test.unreifiable_result_shapes"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+  %d1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+  return %d1 : index
+}
+// CHECK-LABEL: func @test_unreifiable_result_shapes_no_stray_ops(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>)
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[OP:.+]] = "test.unreifiable_result_shapes"(%[[ARG0]])
+// CHECK-NOT: tensor.dim %[[ARG0]]
+// CHECK: %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]]
+// CHECK: return %[[D1]]
``````````
</details>
https://github.com/llvm/llvm-project/pull/188973