[Mlir-commits] [mlir] [mlir][memref] Make memref.cast areCastCompatible return true when meet same types (PR #192029)

lonely eagle llvmlistbot at llvm.org
Tue Apr 14 03:11:58 PDT 2026


https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/192029

>From 576aa870806c52860acebb71f9d85d26fb5bfa19 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Tue, 14 Apr 2026 09:10:45 +0000
Subject: [PATCH 1/2] make memref cast areCastCompatible return true when meet
 same types.

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp |  2 ++
 mlir/test/Dialect/MemRef/invalid.mlir    | 17 ++++++++++++++---
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 27c1649ee4ed3..31e4640499276 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -737,6 +737,8 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (inputs.size() != 1 || outputs.size() != 1)
     return false;
+  if (inputs == outputs)
+    return true;
   Type a = inputs.front(), b = outputs.front();
   auto aT = llvm::dyn_cast<MemRefType>(a);
   auto bT = llvm::dyn_cast<MemRefType>(b);
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index d3670fde08d81..2f061a1bb773e 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -894,12 +894,23 @@ func.func @invalid_memref_cast() {
 
 // -----
 
-// unranked to unranked
+// unranked incompatible element types
 func.func @invalid_memref_cast() {
   %0 = memref.alloc() : memref<2x5xf32, 0>
   %1 = memref.cast %0 : memref<2x5xf32, 0> to memref<*xf32, 0>
-  // expected-error@+1 {{operand type 'memref<*xf32>' and result type 'memref<*xf32>' are cast incompatible}}
-  %2 = memref.cast %1 : memref<*xf32, 0> to memref<*xf32, 0>
+  // expected-error@+1 {{operand type 'memref<*xf32>' and result type 'memref<*xi32>' are cast incompatible}}
+  %2 = memref.cast %1 : memref<*xf32, 0> to memref<*xi32, 0>
+  return
+}
+
+// -----
+
+// unranked incompatible memory space
+func.func @invalid_memref_cast() {
+  %0 = memref.alloc() : memref<2x5xf32, 0>
+  %1 = memref.cast %0 : memref<2x5xf32, 0> to memref<*xf32, 0>
+  // expected-error@+1 {{operand type 'memref<*xf32>' and result type 'memref<*xf32, 1>' are cast incompatible}}
+  %2 = memref.cast %1 : memref<*xf32, 0> to memref<*xf32, 1>
   return
 }
 

>From b3623f660fbac8147423138352d0c4cf843b674e Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Tue, 14 Apr 2026 10:11:41 +0000
Subject: [PATCH 2/2] add valid test for this.

---
 mlir/test/Dialect/MemRef/ops.mlir | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 14ac6a03d6ae0..27ab293856c89 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -235,6 +235,9 @@ func.func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memr
 
   // CHECK: memref.cast %{{.*}} : memref<4x?x8xf32, strided<[32, 8, 1]>> to memref<4x1x8xf32, strided<[32, 16, 1]>>
   %7 = memref.cast %arg4 : memref<4x?x8xf32, strided<[32, 8, 1]>> to memref<4x1x8xf32, strided<[32, 16, 1]>>
+
+  // CHECK: memref.cast %{{.*}} : memref<*xf32> to memref<*xf32>
+  %8 = memref.cast %4 : memref<*xf32> to memref<*xf32>
   return
 }
 



More information about the Mlir-commits mailing list