[Mlir-commits] [mlir] [MLIR][Bufferization] Fix foldMemRefCasts dropping ranked return type for unranked->ranked cast (PR #189249)

Mehdi Amini llvmlistbot at llvm.org
Sun Mar 29 07:27:55 PDT 2026


https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/189249

When one-shot-bufferize runs with bufferize-function-boundaries and a function returns a ranked tensor produced by casting from an unranked intermediate (e.g. the result of a call to a function returning tensor<*xf32>), the foldMemRefCasts post-processing step incorrectly looked through the memref.cast from unranked to ranked memref. As a result, the function return type was downgraded to the unranked memref type and the unranked value was used as the return operand.

The fix is in unpackCast(): do not unpack a cast whose source is an unranked memref and whose result is a ranked memref, since doing so would lose type specificity.
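
To make the rule above concrete, here is a minimal, self-contained C++ sketch of the before/after behavior. It is a toy model only: ToyValue, TypeKind, unpackCastOld and unpackCastNew are hypothetical names invented for illustration and are not MLIR classes or functions. The real check in the patch uses isa<UnrankedMemRefType> and isa<MemRefType> on mlir::Value types.

```cpp
#include <cassert>
#include <string>

// Toy stand-ins for MLIR concepts (assumptions for illustration only).
enum class TypeKind { RankedMemRef, UnrankedMemRef };

struct ToyValue {
  std::string name;
  TypeKind type;
  // Non-null when this value is the result of a (toy) memref.cast;
  // points at the cast's source operand.
  const ToyValue *castSource;
};

// Models the pre-patch helper: always look through a cast.
const ToyValue &unpackCastOld(const ToyValue &v) {
  return v.castSource ? *v.castSource : v;
}

// Models the patched helper: keep the cast when it goes from an
// unranked source to a ranked result, otherwise look through it.
const ToyValue &unpackCastNew(const ToyValue &v) {
  if (!v.castSource)
    return v;
  if (v.castSource->type == TypeKind::UnrankedMemRef &&
      v.type == TypeKind::RankedMemRef)
    return v;
  return *v.castSource;
}
```

In this model, a ranked value cast from an unranked call result is returned unchanged by unpackCastNew, whereas unpackCastOld returns the unranked source, which is the downgrade described above. Casts in the other direction (ranked to unranked) and ranked-to-ranked casts are still looked through by both versions.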

Fixes https://github.com/llvm/llvm-project/issues/176739

Assisted-by: Claude Code

From 51edc09f15e49907589b01c03f98a3de15e89935 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sun, 29 Mar 2026 06:43:34 -0700
Subject: [PATCH] [MLIR][Bufferization] Fix foldMemRefCasts dropping ranked
 return type for unranked->ranked cast

When one-shot-bufferize with bufferize-function-boundaries is used and a
function returns a ranked tensor that is produced by casting from an unranked
intermediate (e.g. a call to a function returning tensor<*xf32>), the
foldMemRefCasts post-processing step incorrectly unpacked the memref.cast
from unranked to ranked memref, downgrading the function return type to the
unranked memref type and using the unranked value as the return operand.

The fix is in unpackCast(): do not unpack a cast whose source is an unranked
memref and whose result is a ranked memref, since doing so would lose type
specificity.

Fixes https://github.com/llvm/llvm-project/issues/176739

Assisted-by: Claude Code
---
 .../Transforms/OneShotModuleBufferize.cpp     |  8 +++++++
 .../Transforms/one-shot-module-bufferize.mlir | 21 +++++++++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index d29150a7403f9..6c5719ce6df8e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -378,10 +378,18 @@ static LogicalResult getFuncOpsOrderedByCalls(
 
 /// Helper function that extracts the source from a memref.cast. If the given
 /// value is not a memref.cast result, simply returns the given value.
+/// Only unpacks casts where the source is at least as specific as the result
+/// (i.e., does not unpack casts from unranked to ranked memref, which would
+/// downgrade the type).
 static Value unpackCast(Value v) {
   auto castOp = v.getDefiningOp<memref::CastOp>();
   if (!castOp)
     return v;
+  // Do not unpack a cast from unranked to ranked memref: folding would
+  // downgrade the function return type from ranked to unranked.
+  if (isa<UnrankedMemRefType>(castOp.getSource().getType()) &&
+      isa<MemRefType>(v.getType()))
+    return v;
   return castOp.getSource();
 }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 8db1ebb87a1e5..d5cb7a0f14f5a 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -884,3 +884,24 @@ func.func @custom_types_bar(%arg: !test.test_tensor<[4, 4], f64>)
   // CHECK: return %[[out]]
   return %out : !test.test_tensor<[4, 8], f64>
 }
+
+// -----
+
+// Test that foldMemRefCasts does not downgrade a ranked return type to unranked
+// when the return value is produced by a memref.cast from unranked to ranked.
+// CHECK-LABEL: func.func @ranked_return_via_unranked_call(
+// CHECK-SAME:      %[[arg:.*]]: memref<64x20x40xf32
+// CHECK-SAME:  ) -> memref<64x20x40xf32
+func.func @ranked_return_via_unranked_call(%arg0: tensor<64x20x40xf32>) -> tensor<64x20x40xf32> {
+  // CHECK: %[[cast:.*]] = memref.cast %[[arg]]
+  // CHECK-SAME: to memref<*xf32>
+  %u = tensor.cast %arg0 : tensor<64x20x40xf32> to tensor<*xf32>
+  // CHECK: %[[call:.*]] = call @relu_unranked(%[[cast]])
+  %r = call @relu_unranked(%u) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: %[[cast2:.*]] = memref.cast %[[call]]
+  // CHECK-SAME: to memref<64x20x40xf32
+  %b = tensor.cast %r : tensor<*xf32> to tensor<64x20x40xf32>
+  // CHECK: return %[[cast2]]
+  return %b : tensor<64x20x40xf32>
+}
+func.func private @relu_unranked(tensor<*xf32>) -> tensor<*xf32>


