[Mlir-commits] [mlir] [mlir][memref] Make memref.cast areCastCompatible return true when the types are the same (PR #192029)
lonely eagle
llvmlistbot at llvm.org
Tue Apr 14 03:11:58 PDT 2026
https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/192029
>From 576aa870806c52860acebb71f9d85d26fb5bfa19 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Tue, 14 Apr 2026 09:10:45 +0000
Subject: [PATCH 1/2] make memref.cast areCastCompatible return true when the
 input and output types are the same.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 2 ++
mlir/test/Dialect/MemRef/invalid.mlir | 17 ++++++++++++++---
2 files changed, 16 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 27c1649ee4ed3..31e4640499276 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -737,6 +737,8 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
+ if (inputs == outputs)
+ return true;
Type a = inputs.front(), b = outputs.front();
auto aT = llvm::dyn_cast<MemRefType>(a);
auto bT = llvm::dyn_cast<MemRefType>(b);
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index d3670fde08d81..2f061a1bb773e 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -894,12 +894,23 @@ func.func @invalid_memref_cast() {
// -----
-// unranked to unranked
+// unranked incompatible element types
func.func @invalid_memref_cast() {
%0 = memref.alloc() : memref<2x5xf32, 0>
%1 = memref.cast %0 : memref<2x5xf32, 0> to memref<*xf32, 0>
- // expected-error at +1 {{operand type 'memref<*xf32>' and result type 'memref<*xf32>' are cast incompatible}}
- %2 = memref.cast %1 : memref<*xf32, 0> to memref<*xf32, 0>
+ // expected-error at +1 {{operand type 'memref<*xf32>' and result type 'memref<*xi32>' are cast incompatible}}
+ %2 = memref.cast %1 : memref<*xf32, 0> to memref<*xi32, 0>
+ return
+}
+
+// -----
+
+// unranked incompatible memory space
+func.func @invalid_memref_cast() {
+ %0 = memref.alloc() : memref<2x5xf32, 0>
+ %1 = memref.cast %0 : memref<2x5xf32, 0> to memref<*xf32, 0>
+ // expected-error at +1 {{operand type 'memref<*xf32>' and result type 'memref<*xf32, 1>' are cast incompatible}}
+ %2 = memref.cast %1 : memref<*xf32, 0> to memref<*xf32, 1>
return
}
>From b3623f660fbac8147423138352d0c4cf843b674e Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Tue, 14 Apr 2026 10:11:41 +0000
Subject: [PATCH 2/2] add a valid test for casting a memref to its own type.
---
mlir/test/Dialect/MemRef/ops.mlir | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 14ac6a03d6ae0..27ab293856c89 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -235,6 +235,9 @@ func.func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memr
// CHECK: memref.cast %{{.*}} : memref<4x?x8xf32, strided<[32, 8, 1]>> to memref<4x1x8xf32, strided<[32, 16, 1]>>
%7 = memref.cast %arg4 : memref<4x?x8xf32, strided<[32, 8, 1]>> to memref<4x1x8xf32, strided<[32, 16, 1]>>
+
+ // CHECK: memref.cast %{{.*}} : memref<*xf32> to memref<*xf32>
+ %8 = memref.cast %4 : memref<*xf32> to memref<*xf32>
return
}
More information about the Mlir-commits
mailing list