[Mlir-commits] [mlir] f7464cb - [MLIR][Bufferization] Fix foldMemRefCasts dropping ranked return type for unranked->ranked cast (#189249)
llvmlistbot at llvm.org
Wed Apr 15 06:10:06 PDT 2026
Author: Mehdi Amini
Date: 2026-04-15T15:10:00+02:00
New Revision: f7464cbdea73a337ee285e119023673249d5600e
URL: https://github.com/llvm/llvm-project/commit/f7464cbdea73a337ee285e119023673249d5600e
DIFF: https://github.com/llvm/llvm-project/commit/f7464cbdea73a337ee285e119023673249d5600e.diff
LOG: [MLIR][Bufferization] Fix foldMemRefCasts dropping ranked return type for unranked->ranked cast (#189249)
When one-shot-bufferize was run with bufferize-function-boundaries and a
function returned a ranked tensor produced by casting from an unranked
intermediate (e.g. the result of a call to a function returning
tensor<*xf32>), the foldMemRefCasts post-processing step incorrectly
unpacked the memref.cast from the unranked to the ranked memref. This
downgraded the function return type to the unranked memref type and used
the unranked value as the return operand.
The fix is in unpackCast(): do not unpack a cast whose source is an
unranked memref and whose result is a ranked memref, since doing so
would lose type specificity.
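The decision made by the fixed unpackCast() can be illustrated with a small
standalone C++ model. This is a sketch under stated assumptions:
ToyMemRefType, ToyValue, unpackCastBeforeFix, unpackCastAfterFix and
foldedReturnType are hypothetical names that only mimic the rule described
above. They do not use the MLIR API, and the model covers a single return
operand rather than everything foldMemRefCasts does.

```cpp
#include <cassert>
#include <string>

// Hypothetical toy model (NOT the MLIR API): a memref type is either
// unranked, like memref<*xf32>, or ranked, like memref<64x20x40xf32>.
struct ToyMemRefType {
  bool ranked;
  std::string spelling;
};

// Hypothetical toy SSA value: its type and, if it is the result of a
// memref.cast, a pointer to the cast's source value.
struct ToyValue {
  ToyMemRefType type;
  const ToyValue *castSource = nullptr; // non-null iff defined by memref.cast
};

// Models the old behavior: always look through a memref.cast.
const ToyValue *unpackCastBeforeFix(const ToyValue *v) {
  return v->castSource ? v->castSource : v;
}

// Models the fixed behavior: keep an unranked -> ranked cast so that the
// ranked result type is not replaced by the unranked source type.
const ToyValue *unpackCastAfterFix(const ToyValue *v) {
  if (!v->castSource)
    return v;
  if (!v->castSource->type.ranked && v->type.ranked)
    return v;
  return v->castSource;
}

// Models the return type derived for a function whose single return
// operand is `ret`, after (toy) cast folding.
std::string foldedReturnType(const ToyValue *ret, bool withFix) {
  const ToyValue *folded =
      withFix ? unpackCastAfterFix(ret) : unpackCastBeforeFix(ret);
  return folded->type.spelling;
}
```

For the shape used in the new test (a call result of type memref<*xf32>
cast to memref<64x20x40xf32> and returned), foldedReturnType(&cast, false)
yields memref<*xf32>, while foldedReturnType(&cast, true) keeps
memref<64x20x40xf32>. Casts in the other direction (ranked to unranked) are
still unpacked, which only makes the return type more specific.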
Fixes https://github.com/llvm/llvm-project/issues/176739
Assisted-by: Claude Code
Added:
Modified:
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index d29150a7403f9..6c5719ce6df8e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -378,10 +378,18 @@ static LogicalResult getFuncOpsOrderedByCalls(
/// Helper function that extracts the source from a memref.cast. If the given
/// value is not a memref.cast result, simply returns the given value.
+/// Only unpacks casts where the source is at least as specific as the result
+/// (i.e., does not unpack casts from unranked to ranked memref, which would
+/// downgrade the type).
static Value unpackCast(Value v) {
auto castOp = v.getDefiningOp<memref::CastOp>();
if (!castOp)
return v;
+ // Do not unpack a cast from unranked to ranked memref: folding would
+ // downgrade the function return type from ranked to unranked.
+ if (isa<UnrankedMemRefType>(castOp.getSource().getType()) &&
+ isa<MemRefType>(v.getType()))
+ return v;
return castOp.getSource();
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 8db1ebb87a1e5..d5cb7a0f14f5a 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -884,3 +884,24 @@ func.func @custom_types_bar(%arg: !test.test_tensor<[4, 4], f64>)
// CHECK: return %[[out]]
return %out : !test.test_tensor<[4, 8], f64>
}
+
+// -----
+
+// Test that foldMemRefCasts does not downgrade a ranked return type to unranked
+// when the return value is produced by a memref.cast from unranked to ranked.
+// CHECK-LABEL: func.func @ranked_return_via_unranked_call(
+// CHECK-SAME: %[[arg:.*]]: memref<64x20x40xf32
+// CHECK-SAME: ) -> memref<64x20x40xf32
+func.func @ranked_return_via_unranked_call(%arg0: tensor<64x20x40xf32>) -> tensor<64x20x40xf32> {
+ // CHECK: %[[cast:.*]] = memref.cast %[[arg]]
+ // CHECK-SAME: to memref<*xf32>
+ %u = tensor.cast %arg0 : tensor<64x20x40xf32> to tensor<*xf32>
+ // CHECK: %[[call:.*]] = call @relu_unranked(%[[cast]])
+ %r = call @relu_unranked(%u) : (tensor<*xf32>) -> tensor<*xf32>
+ // CHECK: %[[cast2:.*]] = memref.cast %[[call]]
+ // CHECK-SAME: to memref<64x20x40xf32
+ %b = tensor.cast %r : tensor<*xf32> to tensor<64x20x40xf32>
+ // CHECK: return %[[cast2]]
+ return %b : tensor<64x20x40xf32>
+}
+func.func private @relu_unranked(tensor<*xf32>) -> tensor<*xf32>