[Mlir-commits] [mlir] [MLIR][Bufferization] Fix foldMemRefCasts dropping ranked return type for unranked->ranked cast (PR #189249)
llvmlistbot at llvm.org
Sun Mar 29 07:28:26 PDT 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
When one-shot-bufferize is run with bufferize-function-boundaries and a function returns a ranked tensor produced by casting from an unranked intermediate (e.g., the result of a call to a function returning tensor<*xf32>), the foldMemRefCasts post-processing step incorrectly unpacked the memref.cast from the unranked to the ranked memref. As a result, the function's return type was downgraded to the unranked memref type, and the unranked value was used as the return operand.
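To make the failure mode concrete, here is a minimal C++ model of the fold. It assumes a function with a single return op whose result types are recomputed from the unpacked return operands. `MemRefTy`, `Val`, `unpackCastOld`, and `foldReturnTypes` are hypothetical names; this is not the MLIR C++ API and it simplifies what foldMemRefCasts actually does.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for MLIR types and values (not the MLIR C++ API).
// A memref type is either unranked (memref<*xf32>) or ranked with a shape
// (memref<64x20x40xf32>).
struct MemRefTy {
  bool ranked;
  std::vector<int64_t> shape;  // empty when unranked
};

// A value whose castSource is non-null models the result of a memref.cast
// whose operand is castSource.
struct Val {
  MemRefTy type;
  const Val *castSource = nullptr;
};

// Pre-fix behavior: always look through a memref.cast.
const Val *unpackCastOld(const Val *v) {
  assert(v && "expected a value");
  return v->castSource ? v->castSource : v;
}

// Simplified fold for a function with a single return op: each returned value
// is replaced by its unpacked value, and the corresponding function result
// type is taken from that value.
std::vector<MemRefTy> foldReturnTypes(const std::vector<const Val *> &returned) {
  std::vector<MemRefTy> resultTypes;
  for (const Val *v : returned)
    resultTypes.push_back(unpackCastOld(v)->type);
  return resultTypes;
}
```

If the returned value is a cast of an unranked call result to memref<64x20x40xf32>, the recomputed result type in this model comes back unranked, which is the downgrade described in the issue.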
The fix is in unpackCast(): do not unpack a cast whose source is an unranked memref and whose result is a ranked memref, since doing so would lose type specificity.
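The guard can be sketched with the same kind of hypothetical model (again, `MemRefTy`, `Val`, and this `unpackCast` are illustrative stand-ins, not the MLIR API). It mirrors the condition in the diff: keep the cast result only when the source is unranked and the result is ranked.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins (not the MLIR C++ API).
struct MemRefTy {
  bool ranked;
  std::vector<int64_t> shape;  // empty when unranked; -1 marks a dynamic dim
};

struct Val {
  MemRefTy type;
  const Val *castSource = nullptr;  // non-null if this is a memref.cast result
};

// Look through a memref.cast unless it goes from an unranked source to a
// ranked result, in which case the ranked cast result is kept.
const Val *unpackCast(const Val *v) {
  assert(v && "expected a value");
  if (!v->castSource)
    return v;  // not a cast: nothing to unpack
  if (!v->castSource->type.ranked && v->type.ranked)
    return v;  // unranked -> ranked: keep the more specific type
  return v->castSource;
}
```

Casts between ranked types, and ranked-to-unranked casts, are still looked through, since the check in the diff only fires for an unranked source combined with a ranked result.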
Fixes https://github.com/llvm/llvm-project/issues/176739
Assisted-by: Claude Code
---
Full diff: https://github.com/llvm/llvm-project/pull/189249.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+8)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir (+21)
``````````diff
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index d29150a7403f9..6c5719ce6df8e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -378,10 +378,18 @@ static LogicalResult getFuncOpsOrderedByCalls(
/// Helper function that extracts the source from a memref.cast. If the given
/// value is not a memref.cast result, simply returns the given value.
+/// Only unpacks casts where the source is at least as specific as the result
+/// (i.e., does not unpack casts from unranked to ranked memref, which would
+/// downgrade the type).
static Value unpackCast(Value v) {
auto castOp = v.getDefiningOp<memref::CastOp>();
if (!castOp)
return v;
+ // Do not unpack a cast from unranked to ranked memref: folding would
+ // downgrade the function return type from ranked to unranked.
+ if (isa<UnrankedMemRefType>(castOp.getSource().getType()) &&
+ isa<MemRefType>(v.getType()))
+ return v;
return castOp.getSource();
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 8db1ebb87a1e5..d5cb7a0f14f5a 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -884,3 +884,24 @@ func.func @custom_types_bar(%arg: !test.test_tensor<[4, 4], f64>)
// CHECK: return %[[out]]
return %out : !test.test_tensor<[4, 8], f64>
}
+
+// -----
+
+// Test that foldMemRefCasts does not downgrade a ranked return type to unranked
+// when the return value is produced by a memref.cast from unranked to ranked.
+// CHECK-LABEL: func.func @ranked_return_via_unranked_call(
+// CHECK-SAME: %[[arg:.*]]: memref<64x20x40xf32
+// CHECK-SAME: ) -> memref<64x20x40xf32
+func.func @ranked_return_via_unranked_call(%arg0: tensor<64x20x40xf32>) -> tensor<64x20x40xf32> {
+ // CHECK: %[[cast:.*]] = memref.cast %[[arg]]
+ // CHECK-SAME: to memref<*xf32>
+ %u = tensor.cast %arg0 : tensor<64x20x40xf32> to tensor<*xf32>
+ // CHECK: %[[call:.*]] = call @relu_unranked(%[[cast]])
+ %r = call @relu_unranked(%u) : (tensor<*xf32>) -> tensor<*xf32>
+ // CHECK: %[[cast2:.*]] = memref.cast %[[call]]
+ // CHECK-SAME: to memref<64x20x40xf32
+ %b = tensor.cast %r : tensor<*xf32> to tensor<64x20x40xf32>
+ // CHECK: return %[[cast2]]
+ return %b : tensor<64x20x40xf32>
+}
+func.func private @relu_unranked(tensor<*xf32>) -> tensor<*xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/189249