[Mlir-commits] [mlir] [mlir][bufferization] Add XFAIL test for bufferize-function-boundaries returning unranked memref (PR #176746)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 6 03:37:05 PST 2026


https://github.com/tridhapuku updated https://github.com/llvm/llvm-project/pull/176746

>From 79c4aed042f614c149368adcbaae0cee1912f33b Mon Sep 17 00:00:00 2001
From: tridhapuku <abhidipu863 at gmail.com>
Date: Mon, 19 Jan 2026 12:26:18 +0000
Subject: [PATCH 1/2] [mlir][bufferization] Add XFAIL regression test for
 one-shot bufferize dropping ranked return

---
 ...ze-func-boundary-ranked-from-unranked.mlir | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)
 create mode 100644 mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-func-boundary-ranked-from-unranked.mlir

diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-func-boundary-ranked-from-unranked.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-func-boundary-ranked-from-unranked.mlir
new file mode 100644
index 0000000000000..e1b025b7b5d0a
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-func-boundary-ranked-from-unranked.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries" | FileCheck %s
+// XFAIL: *
+
+module {
+  func.func @foo(%arg0: tensor<64x20x40xf32>) -> tensor<64x20x40xf32> {
+    %u = tensor.cast %arg0 : tensor<64x20x40xf32> to tensor<*xf32>
+    %r = call @relu(%u) : (tensor<*xf32>) -> tensor<*xf32>
+    %b = tensor.cast %r : tensor<*xf32> to tensor<64x20x40xf32>
+    return %b : tensor<64x20x40xf32>
+  }
+  func.func private @relu(tensor<*xf32>) -> tensor<*xf32>
+}
+
+// CHECK-LABEL: func.func @foo
+// CHECK-SAME: -> memref<64x20x40xf32
+// CHECK: %[[R:.*]] = call @relu
+// CHECK: %[[C:.*]] = memref.cast %[[R]] : memref<*xf32> to memref<64x20x40xf32
+// CHECK: return %[[C]] : memref<64x20x40xf32
+

>From 34b17c42209d1a206014346e3422847f9ada7eec Mon Sep 17 00:00:00 2001
From: tridhapuku <abhidipu863 at gmail.com>
Date: Fri, 6 Mar 2026 11:35:48 +0000
Subject: [PATCH 2/2] [mlir][bufferization] Prevent one-shot bufferize from
 dropping return type precision info

---
 .../Transforms/OneShotModuleBufferize.cpp     | 141 ++++++++++++++++--
 ...ze-func-boundary-ranked-from-unranked.mlir |   1 -
 ...ufferize-return-memref-cast-precision.mlir |  57 +++++++
 3 files changed, 186 insertions(+), 13 deletions(-)
 create mode 100644 mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-return-memref-cast-precision.mlir

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index d29150a7403f9..f101f918c56e8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -376,13 +376,118 @@ static LogicalResult getFuncOpsOrderedByCalls(
   return success();
 }
 
-/// Helper function that extracts the source from a memref.cast. If the given
-/// value is not a memref.cast result, simply returns the given value.
-static Value unpackCast(Value v) {
-  auto castOp = v.getDefiningOp<memref::CastOp>();
-  if (!castOp)
+// Returns true if a memref.cast can be replaced by its source without losing
+// type information that the cast result carries.
+//
+// Purpose (used when folding casts on returned values in foldMemRefCasts):
+// - drop casts that only generalize the type (e.g. ranked -> unranked),
+// - keep casts whose result type is strictly more precise than the source.
+//
+// A cast is reported as safe to drop only if its result type adds none of:
+//   (a) rank information (unranked -> ranked),
+//   (b) static dimension information (dynamic -> static dimension),
+//   (c) layout information (identity result layout from a non-identity
+//       source, or static strided offset/stride from a dynamic one).
+//
+// Notes:
+// - Returns false (keep the cast) when either type is not a memref type,
+//   when ranks differ, or when a non-strided, non-identity layout is
+//   involved.
+// - Only the cast's own source/result types are inspected.
+static bool canSafelyDropMemrefCast(memref::CastOp castOp) {
+  Type srcTy = castOp.getSource().getType();
+  Type dstTy = castOp.getType();
+
+  auto srcU = dyn_cast<UnrankedMemRefType>(srcTy);
+  auto dstU = dyn_cast<UnrankedMemRefType>(dstTy);
+  auto srcR = dyn_cast<MemRefType>(srcTy);
+  auto dstR = dyn_cast<MemRefType>(dstTy);
+
+  // Conservatively keep the cast if the source is not a memref type.
+  if (!srcU && !srcR)
+    return false;
+
+  // Likewise keep the cast if the destination is not a memref type.
+  if (!dstU && !dstR)
+      return false;
+
+  // Rank precision: do not drop unranked -> ranked.
+  if (srcU && dstR)
+    return false; // Cast adds rank information; keep it.
+  if (srcR && dstU)
+    return true;
+  if (srcU && dstU)
+    return true;
+
+  // Both ranked from here.
+  if (srcR.getRank() != dstR.getRank())
+    return false;
+
+  // Shape precision: veto dynamic -> static in dst.
+  ArrayRef<int64_t> s = srcR.getShape();
+  ArrayRef<int64_t> d = dstR.getShape();
+  for (int64_t i = 0, e = srcR.getRank(); i < e; ++i)
+    if (ShapedType::isDynamic(s[i]) && !ShapedType::isDynamic(d[i]))
+      return false;
+
+  // Layout precision:
+  // - if dst is identity and src is not, dst is more specific: keep cast.
+  // - if both are strided, compare offsets and strides element-wise.
+  // - otherwise (custom maps), stay conservative and keep cast.
+  auto sLayout = srcR.getLayout();
+  auto dLayout = dstR.getLayout();
+
+  auto sStrided = dyn_cast<StridedLayoutAttr>(sLayout);
+  auto dStrided = dyn_cast<StridedLayoutAttr>(dLayout);
+
+  bool sCustom = !sStrided && !sLayout.isIdentity();
+  bool dCustom = !dStrided && !dLayout.isIdentity();
+  if (sCustom || dCustom)
+    return sLayout.isIdentity() && dLayout.isIdentity();
+
+  if (dLayout.isIdentity() && !sLayout.isIdentity())
+    return false;
+
+  if (sStrided && dStrided) {
+    if (ShapedType::isDynamic(sStrided.getOffset()) &&
+        !ShapedType::isDynamic(dStrided.getOffset()))
+      return false;
+
+    ArrayRef<int64_t> ss = sStrided.getStrides();
+    ArrayRef<int64_t> ds = dStrided.getStrides();
+    if (ss.size() != ds.size())
+      return false;
+
+    for (size_t i = 0; i < ss.size(); ++i)
+      if (ShapedType::isDynamic(ss[i]) && !ShapedType::isDynamic(ds[i]))
+        return false;
+  }
+
+  return true;
+}
+
+// Returns the memref.cast that defines `v`, or a null CastOp if there is none.
+static memref::CastOp getDefiningMemRefCast(Value v) {
+  return v.getDefiningOp<memref::CastOp>();
+}
+
+// Returns the value whose type should be used when comparing the operands of
+// func.return ops to compute a function's return types.
+//
+// Rule:
+// - when the value is not a memref.cast result, return the value
+// - when the value is a memref.cast result:
+//     - if canSafelyDropMemrefCast(castOp) is true, return the cast source
+//       (the cast result adds no type precision, so it is safe to drop)
+//     - else return the cast result
+//       (the cast adds precision: unranked->ranked, dynamic->static, layout)
+static Value canonicalizeReturnValue(Value v) {
+  if (auto castOp = getDefiningMemRefCast(v)) {
+    if (canSafelyDropMemrefCast(castOp))
+      return castOp.getSource();
     return v;
-  return castOp.getSource();
+  }
+  return v;
 }
 
 /// Helper function that returns the return types (skipping casts) of the given
@@ -393,17 +498,19 @@ static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
   assert(!returnOps.empty() && "expected at least one ReturnOp");
   int numOperands = returnOps.front()->getNumOperands();
 
-  // Helper function that unpacks memref.cast ops and returns the type.
-  auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };
+  // Helper function that conditionally drops memref.cast ops and returns the type.
+  auto getComparableType = [&](Value v) {
+    return canonicalizeReturnValue(v).getType();
+  };
 
   SmallVector<Type> result;
   for (int i = 0; i < numOperands; ++i) {
     // Get the type of the i-th operand of the first func.return ops.
-    Type t = getSourceType(returnOps.front()->getOperand(i));
+    Type t = getComparableType(returnOps.front()->getOperand(i));
 
     // Check if all other func.return ops have a matching operand type.
     for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
-      if (getSourceType(returnOps[j]->getOperand(i)) != t)
+      if (getComparableType(returnOps[j]->getOperand(i)) != t)
         t = Type();
 
     result.push_back(t);
@@ -432,8 +539,18 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
   for (func::ReturnOp returnOp : returnOps) {
     for (OpOperand &operand : returnOp->getOpOperands()) {
       // Bail if no common result type was found.
-      if (resultTypes[operand.getOperandNumber()]) {
-        operand.set(unpackCast(operand.get()));
+      int pos = operand.getOperandNumber();
+
+      // Skip rewriting when no common result type exists for this operand.
+      if (!resultTypes[pos])
+        continue;
+
+      Value v = operand.get();
+      if (auto castOp = getDefiningMemRefCast(v)) {
+        // Strip the cast only when doing so loses no type precision.
+        if (canSafelyDropMemrefCast(castOp))
+          operand.set(castOp.getSource());
+        // else keep operand as-is (cast result)
       }
     }
   }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-func-boundary-ranked-from-unranked.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-func-boundary-ranked-from-unranked.mlir
index e1b025b7b5d0a..31972e8d5d1a6 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-func-boundary-ranked-from-unranked.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-func-boundary-ranked-from-unranked.mlir
@@ -1,5 +1,4 @@
 // RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries" | FileCheck %s
-// XFAIL: *
 
 module {
   func.func @foo(%arg0: tensor<64x20x40xf32>) -> tensor<64x20x40xf32> {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-return-memref-cast-precision.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-return-memref-cast-precision.mlir
new file mode 100644
index 0000000000000..d0430d28bd3dc
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-return-memref-cast-precision.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt -one-shot-bufferize="bufferize-function-boundaries" %s | FileCheck %s
+
+
+// Test that foldMemRefCasts does not drop a precision-gaining return cast from
+// unranked to ranked shape.
+func.func private @callee() -> tensor<*xf32>
+
+func.func @test_keep_ranked_return() -> tensor<8xf32> {
+  %0 = func.call @callee() : () -> tensor<*xf32>
+  %1 = tensor.cast %0 : tensor<*xf32> to tensor<8xf32>
+  return %1 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func.func @test_keep_ranked_return()
+// CHECK-SAME: -> memref<8xf32{{.*}}>
+// CHECK: %[[SRC:.*]] = call @callee() : () -> memref<*xf32>
+// CHECK: %[[CAST:.*]] = memref.cast %[[SRC]] : memref<*xf32> to memref<8xf32{{.*}}>
+// CHECK: return %[[CAST]] : memref<8xf32{{.*}}>
+// CHECK-NOT: return %[[SRC]] : memref<*xf32>
+
+// Test that foldMemRefCasts does not drop a precision-gaining return cast from
+// dynamic to static shape.
+func.func private @callee_dynamic_shape() -> tensor<?xf32>
+
+func.func @test_keep_static_shape_return() -> tensor<8xf32> {
+  %0 = func.call @callee_dynamic_shape() : () -> tensor<?xf32>
+  %1 = tensor.cast %0 : tensor<?xf32> to tensor<8xf32>
+  return %1 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func.func @test_keep_static_shape_return()
+// CHECK-SAME: -> memref<8xf32{{.*}}>
+// CHECK: %[[SRC:.*]] = call @callee_dynamic_shape() : () -> memref<?xf32{{.*}}>
+// CHECK: %[[CAST:.*]] = memref.cast %[[SRC]] : memref<?xf32{{.*}}> to memref<8xf32{{.*}}>
+// CHECK: return %[[CAST]] : memref<8xf32{{.*}}>
+// CHECK-NOT: return %[[SRC]] : memref<?xf32{{.*}}>
+
+// Test current bufferization behavior for a tensor with layout annotation.
+// This input does not currently produce a return memref.cast; the result stays
+// in the strided memref form.
+
+#src_layout = affine_map<(d0) -> (d0)>
+
+func.func private @callee_layout() -> tensor<8xf32, #src_layout>
+
+func.func @test_layout_return() -> tensor<8xf32> {
+  %0 = func.call @callee_layout() : () -> tensor<8xf32, #src_layout>
+  %1 = tensor.cast %0 : tensor<8xf32, #src_layout> to tensor<8xf32>
+  return %1 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func.func @test_layout_return()
+// CHECK-SAME: -> memref<8xf32, strided<[?], offset: ?>>
+/// CHECK: %[[SRC:.*]] = call @callee_layout() : () -> memref<8xf32, strided<[?], offset: ?>>
+/// CHECK-NOT: memref.cast
+/// CHECK: return %[[SRC]] : memref<8xf32, strided<[?], offset: ?>>
+



More information about the Mlir-commits mailing list