[Mlir-commits] [mlir] [mlir][bufferization] Add XFAIL test for bufferize-function-boundaries returning unranked memref (PR #176746)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 6 03:37:05 PST 2026
https://github.com/tridhapuku updated https://github.com/llvm/llvm-project/pull/176746
>From 79c4aed042f614c149368adcbaae0cee1912f33b Mon Sep 17 00:00:00 2001
From: tridhapuku <abhidipu863 at gmail.com>
Date: Mon, 19 Jan 2026 12:26:18 +0000
Subject: [PATCH 1/2] [mlir][bufferization] Add XFAIL regression test for
one-shot bufferize dropping ranked return
---
...ze-func-boundary-ranked-from-unranked.mlir | 19 +++++++++++++++++++
1 file changed, 19 insertions(+)
create mode 100644 mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-func-boundary-ranked-from-unranked.mlir
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-func-boundary-ranked-from-unranked.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-func-boundary-ranked-from-unranked.mlir
new file mode 100644
index 0000000000000..e1b025b7b5d0a
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-func-boundary-ranked-from-unranked.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries" | FileCheck %s
+// XFAIL: *
+
+module {
+ func.func @foo(%arg0: tensor<64x20x40xf32>) -> tensor<64x20x40xf32> {
+ %u = tensor.cast %arg0 : tensor<64x20x40xf32> to tensor<*xf32>
+ %r = call @relu(%u) : (tensor<*xf32>) -> tensor<*xf32>
+ %b = tensor.cast %r : tensor<*xf32> to tensor<64x20x40xf32>
+ return %b : tensor<64x20x40xf32>
+ }
+ func.func private @relu(tensor<*xf32>) -> tensor<*xf32>
+}
+
+// CHECK-LABEL: func.func @foo
+// CHECK-SAME: -> memref<64x20x40xf32
+// CHECK: %[[R:.*]] = call @relu
+// CHECK: %[[C:.*]] = memref.cast %[[R]] : memref<*xf32> to memref<64x20x40xf32
+// CHECK: return %[[C]] : memref<64x20x40xf32
+
>From 34b17c42209d1a206014346e3422847f9ada7eec Mon Sep 17 00:00:00 2001
From: tridhapuku <abhidipu863 at gmail.com>
Date: Fri, 6 Mar 2026 11:35:48 +0000
Subject: [PATCH 2/2] [mlir][bufferization] Prevent one-shot bufferize from
dropping return type precision info
---
.../Transforms/OneShotModuleBufferize.cpp | 141 ++++++++++++++++--
...ze-func-boundary-ranked-from-unranked.mlir | 1 -
...ufferize-return-memref-cast-precision.mlir | 57 +++++++
3 files changed, 186 insertions(+), 13 deletions(-)
create mode 100644 mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-return-memref-cast-precision.mlir
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index d29150a7403f9..f101f918c56e8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -376,13 +376,118 @@ static LogicalResult getFuncOpsOrderedByCalls(
return success();
}
-/// Helper function that extracts the source from a memref.cast. If the given
-/// value is not a memref.cast result, simply returns the given value.
-static Value unpackCast(Value v) {
- auto castOp = v.getDefiningOp<memref::CastOp>();
- if (!castOp)
+/// Returns true if the given memref.cast can be dropped at a func.return, i.e.
+/// if returning the cast source instead of the cast result does not lose type
+/// information that foldMemRefCasts should preserve when tightening function
+/// result types.
+///
+/// foldMemRefCasts is meant to drop "signature accommodation" casts (usually
+/// casts that generalize the layout) while keeping casts whose result type is
+/// strictly more precise than their source. A cast is treated as
+/// precision-gaining, and must be kept, if its result type adds:
+///   (a) rank information (unranked source, ranked result),
+///   (b) static dimension sizes (dynamic in the source, static in the result),
+///   (c) layout information (an identity layout where the source is not the
+///       identity, or a static offset/stride where the source's is dynamic).
+///
+/// Non-identity layouts that are not StridedLayoutAttr are not analyzed: a
+/// cast involving one is conservatively kept. Ranked -> unranked and
+/// unranked -> unranked casts add no information and may always be dropped.
+/// Returns false if either the source or the result type is not a memref.
+static bool canSafelyDropMemrefCast(memref::CastOp castOp) {
+  Type srcTy = castOp.getSource().getType();
+  Type dstTy = castOp.getType();
+
+  auto srcU = dyn_cast<UnrankedMemRefType>(srcTy);
+  auto dstU = dyn_cast<UnrankedMemRefType>(dstTy);
+  auto srcR = dyn_cast<MemRefType>(srcTy);
+  auto dstR = dyn_cast<MemRefType>(dstTy);
+
+  // Only memref -> memref casts are analyzed; be conservative otherwise.
+  if (!srcU && !srcR)
+    return false;
+
+  // Same for the result type: it must be a ranked or unranked memref.
+  if (!dstU && !dstR)
+    return false;
+
+  // Rank precision: never drop an unranked -> ranked cast.
+  if (srcU && dstR)
+    return false;
+  if (srcR && dstU)
+    return true;
+  if (srcU && dstU)
+    return true;
+
+  // Both ranked from here.
+  if (srcR.getRank() != dstR.getRank())
+    return false;
+
+  // Shape precision: keep the cast if it turns a dynamic dim into a static one.
+  ArrayRef<int64_t> s = srcR.getShape();
+  ArrayRef<int64_t> d = dstR.getShape();
+  for (int64_t i = 0, e = srcR.getRank(); i < e; ++i)
+    if (ShapedType::isDynamic(s[i]) && !ShapedType::isDynamic(d[i]))
+      return false;
+
+  // Layout precision:
+  // - if dst is identity and src is not, dst is more specific: keep cast.
+  // - if both are strided, compare offsets and strides.
+  // - otherwise (other non-identity layouts), stay conservative: keep cast.
+  auto sLayout = srcR.getLayout();
+  auto dLayout = dstR.getLayout();
+
+  auto sStrided = dyn_cast<StridedLayoutAttr>(sLayout);
+  auto dStrided = dyn_cast<StridedLayoutAttr>(dLayout);
+
+  bool sCustom = !sStrided && !sLayout.isIdentity();
+  bool dCustom = !dStrided && !dLayout.isIdentity();
+  if (sCustom || dCustom)
+    return false; // Non-strided, non-identity layouts are not analyzed.
+
+  if (dLayout.isIdentity() && !sLayout.isIdentity())
+    return false;
+
+  if (sStrided && dStrided) {
+    if (ShapedType::isDynamic(sStrided.getOffset()) &&
+        !ShapedType::isDynamic(dStrided.getOffset()))
+      return false;
+
+    ArrayRef<int64_t> ss = sStrided.getStrides();
+    ArrayRef<int64_t> ds = dStrided.getStrides();
+    if (ss.size() != ds.size())
+      return false;
+
+    for (size_t i = 0, e = ss.size(); i < e; ++i)
+      if (ShapedType::isDynamic(ss[i]) && !ShapedType::isDynamic(ds[i]))
+        return false;
+  }
+
+  return true;
+}
+
+/// Returns the memref.cast op that produces `v`, or a null op otherwise.
+static memref::CastOp getDefiningMemRefCast(Value v) {
+  return dyn_cast_or_null<memref::CastOp>(v.getDefiningOp());
+}
+
+/// Returns the value whose type getReturnTypes compares across all
+/// func.return ops of a function when computing a common result type.
+///
+/// - If `v` is not the result of a memref.cast, returns `v`.
+/// - If `v` is the result of a memref.cast that canSafelyDropMemrefCast
+///   accepts (the cast adds no rank, static-size or layout information),
+///   returns the cast source.
+/// - Otherwise the cast is precision-gaining (e.g. unranked -> ranked or
+///   dynamic -> static) and `v`, the cast result, is returned unchanged.
+/// NOTE(review): foldMemRefCasts open-codes the same rule; keep them in sync.
+static Value canonicalizeReturnValue(Value v) {
+ if (auto castOp = getDefiningMemRefCast(v)) {
+ if (canSafelyDropMemrefCast(castOp))
+ return castOp.getSource();
 return v;
- return castOp.getSource();
+ }
+ return v;
 }
/// Helper function that returns the return types (skipping casts) of the given
@@ -393,17 +498,19 @@ static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
assert(!returnOps.empty() && "expected at least one ReturnOp");
int numOperands = returnOps.front()->getNumOperands();
- // Helper function that unpacks memref.cast ops and returns the type.
- auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };
+ // Helper function that conditionally drops memref.cast ops and returns the type.
+ auto getComparableType = [&](Value v) {
+ return canonicalizeReturnValue(v).getType();
+ };
SmallVector<Type> result;
for (int i = 0; i < numOperands; ++i) {
// Get the type of the i-th operand of the first func.return ops.
- Type t = getSourceType(returnOps.front()->getOperand(i));
+ Type t = getComparableType(returnOps.front()->getOperand(i));
// Check if all other func.return ops have a matching operand type.
for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
- if (getSourceType(returnOps[j]->getOperand(i)) != t)
+ if (getComparableType(returnOps[j]->getOperand(i)) != t)
t = Type();
result.push_back(t);
@@ -432,8 +539,18 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
for (func::ReturnOp returnOp : returnOps) {
for (OpOperand &operand : returnOp->getOpOperands()) {
// Bail if no common result type was found.
- if (resultTypes[operand.getOperandNumber()]) {
- operand.set(unpackCast(operand.get()));
+ int pos = operand.getOperandNumber();
+
+ // this code will skip rewriting when no common result type exists for pos
+ if (!resultTypes[pos])
+ continue;
+
+ Value v = operand.get();
+ if (auto castOp = getDefiningMemRefCast(v)) {
+ // this code will strip cast only when safe
+ if (canSafelyDropMemrefCast(castOp))
+ operand.set(castOp.getSource());
+ // else keep operand as-is (cast result)
}
}
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-func-boundary-ranked-from-unranked.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-func-boundary-ranked-from-unranked.mlir
index e1b025b7b5d0a..31972e8d5d1a6 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-func-boundary-ranked-from-unranked.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-func-boundary-ranked-from-unranked.mlir
@@ -1,5 +1,4 @@
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries" | FileCheck %s
-// XFAIL: *
module {
func.func @foo(%arg0: tensor<64x20x40xf32>) -> tensor<64x20x40xf32> {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-return-memref-cast-precision.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-return-memref-cast-precision.mlir
new file mode 100644
index 0000000000000..d0430d28bd3dc
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-return-memref-cast-precision.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt -one-shot-bufferize="bufferize-function-boundaries" %s | FileCheck %s
+
+
+// Test that foldMemRefCasts does not drop a precision-gaining return cast from
+// unranked to ranked shape.
+func.func private @callee() -> tensor<*xf32>
+
+func.func @test_keep_ranked_return() -> tensor<8xf32> {
+ %0 = func.call @callee() : () -> tensor<*xf32>
+ %1 = tensor.cast %0 : tensor<*xf32> to tensor<8xf32>
+ return %1 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func.func @test_keep_ranked_return()
+// CHECK-SAME: -> memref<8xf32{{.*}}>
+// CHECK: %[[SRC:.*]] = call @callee() : () -> memref<*xf32>
+// CHECK: %[[CAST:.*]] = memref.cast %[[SRC]] : memref<*xf32> to memref<8xf32{{.*}}>
+// CHECK: return %[[CAST]] : memref<8xf32{{.*}}>
+// CHECK-NOT: return %[[SRC]] : memref<*xf32>
+
+// Test that foldMemRefCasts does not drop a precision-gaining return cast from
+// dynamic to static shape.
+func.func private @callee_dynamic_shape() -> tensor<?xf32>
+
+func.func @test_keep_static_shape_return() -> tensor<8xf32> {
+ %0 = func.call @callee_dynamic_shape() : () -> tensor<?xf32>
+ %1 = tensor.cast %0 : tensor<?xf32> to tensor<8xf32>
+ return %1 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func.func @test_keep_static_shape_return()
+// CHECK-SAME: -> memref<8xf32{{.*}}>
+// CHECK: %[[SRC:.*]] = call @callee_dynamic_shape() : () -> memref<?xf32{{.*}}>
+// CHECK: %[[CAST:.*]] = memref.cast %[[SRC]] : memref<?xf32{{.*}}> to memref<8xf32{{.*}}>
+// CHECK: return %[[CAST]] : memref<8xf32{{.*}}>
+// CHECK-NOT: return %[[SRC]] : memref<?xf32{{.*}}>
+
+// Test a return tensor.cast that only drops a tensor encoding (here an
+// identity affine_map). This currently produces no memref.cast at the return;
+// the result keeps the fully dynamic strided layout.
+
+#src_layout = affine_map<(d0) -> (d0)>
+
+func.func private @callee_layout() -> tensor<8xf32, #src_layout>
+
+func.func @test_layout_return() -> tensor<8xf32> {
+ %0 = func.call @callee_layout() : () -> tensor<8xf32, #src_layout>
+ %1 = tensor.cast %0 : tensor<8xf32, #src_layout> to tensor<8xf32>
+ return %1 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func.func @test_layout_return()
+// CHECK-SAME: -> memref<8xf32, strided<[?], offset: ?>>
+// CHECK: %[[SRC:.*]] = call @callee_layout() : () -> memref<8xf32, strided<[?], offset: ?>>
+// CHECK-NOT: memref.cast
+// CHECK: return %[[SRC]] : memref<8xf32, strided<[?], offset: ?>>
+
More information about the Mlir-commits
mailing list