[Mlir-commits] [mlir] [MLIR][Bufferization] Fix foldMemRefCasts dropping ranked return type for unranked->ranked cast (PR #189249)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 29 07:28:26 PDT 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

<details>
<summary>Changes</summary>

When `one-shot-bufferize` is run with `bufferize-function-boundaries`, a function may return a ranked tensor that is produced by casting an unranked intermediate (e.g. the result of a call to a function returning `tensor<*xf32>`). In that case the `foldMemRefCasts` post-processing step incorrectly unpacked the `memref.cast` from the unranked to the ranked memref. This downgraded the function's return type to the unranked memref type and returned the unranked value instead of the ranked one.

The fix is in `unpackCast()`: it no longer unpacks a cast whose source is an unranked memref and whose result is a ranked memref, because looking through such a cast would discard the rank and shape information carried by the ranked result type.

Fixes https://github.com/llvm/llvm-project/issues/176739

Assisted-by: Claude Code

---
Full diff: https://github.com/llvm/llvm-project/pull/189249.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+8) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir (+21) 


``````````diff
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index d29150a7403f9..6c5719ce6df8e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -378,10 +378,18 @@ static LogicalResult getFuncOpsOrderedByCalls(
 
 /// Helper function that extracts the source from a memref.cast. If the given
 /// value is not a memref.cast result, simply returns the given value.
+/// Only unpacks casts where the source is at least as specific as the result
+/// (i.e., does not unpack casts from unranked to ranked memref, which would
+/// downgrade the type).
 static Value unpackCast(Value v) {
   auto castOp = v.getDefiningOp<memref::CastOp>();
   if (!castOp)
     return v;
+  // Do not unpack a cast from unranked to ranked memref: folding would
+  // downgrade the function return type from ranked to unranked.
+  if (isa<UnrankedMemRefType>(castOp.getSource().getType()) &&
+      isa<MemRefType>(v.getType()))
+    return v;
   return castOp.getSource();
 }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 8db1ebb87a1e5..d5cb7a0f14f5a 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -884,3 +884,24 @@ func.func @custom_types_bar(%arg: !test.test_tensor<[4, 4], f64>)
   // CHECK: return %[[out]]
   return %out : !test.test_tensor<[4, 8], f64>
 }
+
+// -----
+
+// Test that foldMemRefCasts does not downgrade a ranked return type to unranked
+// when the return value is produced by a memref.cast from unranked to ranked.
+// CHECK-LABEL: func.func @ranked_return_via_unranked_call(
+// CHECK-SAME:      %[[arg:.*]]: memref<64x20x40xf32
+// CHECK-SAME:  ) -> memref<64x20x40xf32
+func.func @ranked_return_via_unranked_call(%arg0: tensor<64x20x40xf32>) -> tensor<64x20x40xf32> {
+  // CHECK: %[[cast:.*]] = memref.cast %[[arg]]
+  // CHECK-SAME: to memref<*xf32>
+  %u = tensor.cast %arg0 : tensor<64x20x40xf32> to tensor<*xf32>
+  // CHECK: %[[call:.*]] = call @relu_unranked(%[[cast]])
+  %r = call @relu_unranked(%u) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: %[[cast2:.*]] = memref.cast %[[call]]
+  // CHECK-SAME: to memref<64x20x40xf32
+  %b = tensor.cast %r : tensor<*xf32> to tensor<64x20x40xf32>
+  // CHECK: return %[[cast2]]
+  return %b : tensor<64x20x40xf32>
+}
+func.func private @relu_unranked(tensor<*xf32>) -> tensor<*xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/189249

