[Mlir-commits] [mlir] 9fa05b0 - [mlir][memref] Make memref.cast areCastCompatible return true when meet same types (#192029)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 17 06:33:27 PDT 2026


Author: lonely eagle
Date: 2026-04-17T21:33:22+08:00
New Revision: 9fa05b0377de8d9f573b9a7f026bd202ef9e897a

URL: https://github.com/llvm/llvm-project/commit/9fa05b0377de8d9f573b9a7f026bd202ef9e897a
DIFF: https://github.com/llvm/llvm-project/commit/9fa05b0377de8d9f573b9a7f026bd202ef9e897a.diff

LOG: [mlir][memref] Make memref.cast areCastCompatible return true when meet same types (#192029)

When both the source and destination types of `memref.cast` are
unranked, IR verification fails even if the two types are identical.
This breaks downstream projects and contradicts the op's documentation.
To address this, `CastOp::areCastCompatible` now returns true whenever
the source and destination types are identical.

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/invalid.mlir
    mlir/test/Dialect/MemRef/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 27c1649ee4ed3..31e4640499276 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -737,6 +737,8 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (inputs.size() != 1 || outputs.size() != 1)
     return false;
+  if (inputs == outputs)
+    return true;
   Type a = inputs.front(), b = outputs.front();
   auto aT = llvm::dyn_cast<MemRefType>(a);
   auto bT = llvm::dyn_cast<MemRefType>(b);

diff  --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index d3670fde08d81..2f061a1bb773e 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -894,12 +894,23 @@ func.func @invalid_memref_cast() {
 
 // -----
 
-// unranked to unranked
+// unranked incompatible element types
 func.func @invalid_memref_cast() {
   %0 = memref.alloc() : memref<2x5xf32, 0>
   %1 = memref.cast %0 : memref<2x5xf32, 0> to memref<*xf32, 0>
-  // expected-error at +1 {{operand type 'memref<*xf32>' and result type 'memref<*xf32>' are cast incompatible}}
-  %2 = memref.cast %1 : memref<*xf32, 0> to memref<*xf32, 0>
+  // expected-error at +1 {{operand type 'memref<*xf32>' and result type 'memref<*xi32>' are cast incompatible}}
+  %2 = memref.cast %1 : memref<*xf32, 0> to memref<*xi32, 0>
+  return
+}
+
+// -----
+
+// unranked incompatible memory space
+func.func @invalid_memref_cast() {
+  %0 = memref.alloc() : memref<2x5xf32, 0>
+  %1 = memref.cast %0 : memref<2x5xf32, 0> to memref<*xf32, 0>
+  // expected-error at +1 {{operand type 'memref<*xf32>' and result type 'memref<*xf32, 1>' are cast incompatible}}
+  %2 = memref.cast %1 : memref<*xf32, 0> to memref<*xf32, 1>
   return
 }
 

diff  --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 14ac6a03d6ae0..27ab293856c89 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -235,6 +235,9 @@ func.func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memr
 
   // CHECK: memref.cast %{{.*}} : memref<4x?x8xf32, strided<[32, 8, 1]>> to memref<4x1x8xf32, strided<[32, 16, 1]>>
   %7 = memref.cast %arg4 : memref<4x?x8xf32, strided<[32, 8, 1]>> to memref<4x1x8xf32, strided<[32, 16, 1]>>
+
+  // CHECK: memref.cast %{{.*}} : memref<*xf32> to memref<*xf32>
+  %8 = memref.cast %4 : memref<*xf32> to memref<*xf32>
   return
 }
 


        


More information about the Mlir-commits mailing list