[Mlir-commits] [mlir] [mlir][memref] Make memref.cast areCastCompatible return true when meet same types (PR #192029)

lonely eagle llvmlistbot at llvm.org
Tue Apr 14 02:19:02 PDT 2026


https://github.com/linuxlonelyeagle created https://github.com/llvm/llvm-project/pull/192029

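When the source and destination types of `memref.cast` are both unranked, `CastOp::areCastCompatible` rejects the cast even if the two types are identical (for example, `memref<*xf32>` to `memref<*xf32>`). The IR then fails verification, which impacts downstream projects. This PR makes `areCastCompatible` return true whenever the source and destination types are identical. Unranked-to-unranked casts between different element types or memory spaces are still rejected, as the updated tests in `invalid.mlir` check.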
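The effect of the early return can be sketched with a small, self-contained C++ model. It is not MLIR code. `UnrankedMemRefModel`, `areCastCompatibleModel`, and `checkInvalidMlirCases` are hypothetical names, `std::vector` stands in for `TypeRange`, and the model covers only unranked-to-unranked casts, tracking just the element type and memory space that the updated tests vary.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for an unranked memref type such as memref<*xf32, 1>.
// Only the properties varied by the tests in this PR are modeled.
struct UnrankedMemRefModel {
  std::string elementType; // e.g. "f32" or "i32"
  unsigned memorySpace;    // 0 is the default memory space

  bool operator==(const UnrankedMemRefModel &other) const {
    return elementType == other.elementType &&
           memorySpace == other.memorySpace;
  }
};

// Simplified model of CastOp::areCastCompatible, restricted to
// unranked-to-unranked casts (ranked types are not modeled).
bool areCastCompatibleModel(const std::vector<UnrankedMemRefModel> &inputs,
                            const std::vector<UnrankedMemRefModel> &outputs) {
  // memref.cast has exactly one operand and one result.
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  // Early return added by this PR: identical types are always compatible.
  if (inputs == outputs)
    return true;
  // Any other unranked-to-unranked cast is still rejected.
  return false;
}

// Mirrors the three unranked cases touched by invalid.mlir in this patch.
void checkInvalidMlirCases() {
  const UnrankedMemRefModel f32Default{"f32", 0};
  const UnrankedMemRefModel i32Default{"i32", 0};
  const UnrankedMemRefModel f32Space1{"f32", 1};

  // memref<*xf32> to memref<*xf32>: accepted now (was an error before).
  assert(areCastCompatibleModel({f32Default}, {f32Default}));
  // memref<*xf32> to memref<*xi32>: still "cast incompatible".
  assert(!areCastCompatibleModel({f32Default}, {i32Default}));
  // memref<*xf32> to memref<*xf32, 1>: still "cast incompatible".
  assert(!areCastCompatibleModel({f32Default}, {f32Space1}));
}
```

Ranked-to-unranked casts, such as the `memref<2x5xf32, 0>` to `memref<*xf32, 0>` cast that each test starts with, go through the separate `MemRefType` checks in the real `areCastCompatible` and are left out of this model.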

From 576aa870806c52860acebb71f9d85d26fb5bfa19 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Tue, 14 Apr 2026 09:10:45 +0000
Subject: [PATCH] make memref.cast areCastCompatible return true for identical
 types.

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp |  2 ++
 mlir/test/Dialect/MemRef/invalid.mlir    | 17 ++++++++++++++---
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 27c1649ee4ed3..31e4640499276 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -737,6 +737,8 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (inputs.size() != 1 || outputs.size() != 1)
     return false;
+  if (inputs == outputs)
+    return true;
   Type a = inputs.front(), b = outputs.front();
   auto aT = llvm::dyn_cast<MemRefType>(a);
   auto bT = llvm::dyn_cast<MemRefType>(b);
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index d3670fde08d81..2f061a1bb773e 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -894,12 +894,23 @@ func.func @invalid_memref_cast() {
 
 // -----
 
-// unranked to unranked
+// unranked incompatible element types
 func.func @invalid_memref_cast() {
   %0 = memref.alloc() : memref<2x5xf32, 0>
   %1 = memref.cast %0 : memref<2x5xf32, 0> to memref<*xf32, 0>
-  // expected-error@+1 {{operand type 'memref<*xf32>' and result type 'memref<*xf32>' are cast incompatible}}
-  %2 = memref.cast %1 : memref<*xf32, 0> to memref<*xf32, 0>
+  // expected-error@+1 {{operand type 'memref<*xf32>' and result type 'memref<*xi32>' are cast incompatible}}
+  %2 = memref.cast %1 : memref<*xf32, 0> to memref<*xi32, 0>
+  return
+}
+
+// -----
+
+// unranked incompatible memory space
+func.func @invalid_memref_cast() {
+  %0 = memref.alloc() : memref<2x5xf32, 0>
+  %1 = memref.cast %0 : memref<2x5xf32, 0> to memref<*xf32, 0>
+  // expected-error@+1 {{operand type 'memref<*xf32>' and result type 'memref<*xf32, 1>' are cast incompatible}}
+  %2 = memref.cast %1 : memref<*xf32, 0> to memref<*xf32, 1>
   return
 }
 
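The `inputs == outputs` early return in the `MemRefOps.cpp` hunk relies on MLIR types being uniqued within an `MLIRContext`: comparing two `Type` values compares their storage pointers, and `TypeRange` equality compares elements pairwise. The sketch below models that idea with a hypothetical `TypeContextModel` that interns canonical type spellings; it is an illustration under those assumptions, not MLIR's actual type uniquer.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Handle to an interned type; equal handles mean identical types.
using TypeHandle = const std::string *;

// Hypothetical miniature context that uniques type spellings, loosely
// mirroring how MLIR hands out one storage object per distinct type.
// Spellings are assumed canonical (e.g. "memref<*xf32>", with the default
// memory space omitted, as in the diagnostics above).
class TypeContextModel {
public:
  TypeHandle get(const std::string &spelling) {
    // std::set is node-based, so element addresses stay stable on insert.
    return &*storage.insert(spelling).first;
  }

private:
  std::set<std::string> storage;
};

// Element-wise range equality, as in `inputs == outputs`: with uniqued
// types each element comparison is a pointer comparison.
bool typeRangesEqual(const std::vector<TypeHandle> &lhs,
                     const std::vector<TypeHandle> &rhs) {
  return lhs == rhs;
}

// Two separately requested memref<*xf32> handles are the same object,
// so a cast between them would hit the new early return.
void checkIdenticalUnrankedTypes() {
  TypeContextModel ctx;
  TypeHandle src = ctx.get("memref<*xf32>");
  TypeHandle dst = ctx.get("memref<*xf32>");
  assert(src == dst);
  assert(typeRangesEqual({src}, {dst}));
}
```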