[Mlir-commits] [mlir] [mlir] update memref.cast cast compatible check (PR #179313)
ofri frishman
llvmlistbot at llvm.org
Mon Feb 2 11:56:27 PST 2026
https://github.com/ofri-frishman created https://github.com/llvm/llvm-project/pull/179313
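Update the memref.cast compatibility check, which decides whether the input and output types are valid for casting.
Currently, when casting between dynamic and static dims with different strides, the result of the check is not symmetric: it depends on whether the cast goes from dynamic to static or vice versa. This happens because the stride of a size-1 dimension is ignored only when the source dimension has size 1, not when the destination dimension does. Update the check logic to make it symmetric.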
>From 6de84a0ddc5ae664e17d90192cae71a54097fcc0 Mon Sep 17 00:00:00 2001
From: Ofri Frishman <ofri4321 at gmail.com>
Date: Mon, 2 Feb 2026 21:49:57 +0200
Subject: [PATCH] [mlir] update memref.cast cast compatible check
Update the memref.cast compatibility check, which decides whether the
input and output types are valid for casting.
Currently, when casting between dynamic and static dims with
different strides, the result of the check is not symmetric: it
depends on whether the cast goes from dynamic to static or vice versa.
Update the check logic to make it symmetric.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 4 ++--
mlir/test/Dialect/MemRef/ops.mlir | 8 +++++++-
2 files changed, 9 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 4ac8505c1223a..96e5ecd3bb23e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -766,7 +766,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (!checkCompatible(aOffset, bOffset))
return false;
for (const auto &[index, aStride] : enumerate(aStrides)) {
- if (aT.getDimSize(index) == 1)
+ if (aT.getDimSize(index) == 1 || bT.getDimSize(index) == 1)
continue;
if (!checkCompatible(aStride, bStrides[index]))
return false;
@@ -1128,7 +1128,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
}
} // else dim.getIndex is a block argument to reshape->getBlock and
// dominates reshape
- } // Check condition 2
+ } // Check condition 2
else if (dim->getBlock() != reshape->getBlock() &&
!dim.getIndex().getParentRegion()->isProperAncestor(
reshape->getParentRegion())) {
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index cddc79f693b11..14ac6a03d6ae0 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -211,7 +211,7 @@ func.func @memref_alloca_scope() {
}
// CHECK-LABEL: func @memref_cast(%arg0
-func.func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memref<64x16x4xf32, strided<[64, 4, 1], offset: 0>>) {
+func.func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memref<64x16x4xf32, strided<[64, 4, 1], offset: 0>>, %arg3 : memref<4x1x8xf32, strided<[32, 16, 1]>>, %arg4 : memref<4x?x8xf32, strided<[32, 8, 1]>>) {
// CHECK: memref.cast %{{.*}} : memref<4xf32> to memref<?xf32>
%0 = memref.cast %arg0 : memref<4xf32> to memref<?xf32>
@@ -229,6 +229,12 @@ func.func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memr
// CHECK: memref.cast %{{.*}} : memref<*xf32> to memref<4xf32>
%5 = memref.cast %4 : memref<*xf32> to memref<4xf32>
+
+ // CHECK: memref.cast %{{.*}} : memref<4x1x8xf32, strided<[32, 16, 1]>> to memref<4x?x8xf32, strided<[32, 8, 1]>>
+ %6 = memref.cast %arg3 : memref<4x1x8xf32, strided<[32, 16, 1]>> to memref<4x?x8xf32, strided<[32, 8, 1]>>
+
+ // CHECK: memref.cast %{{.*}} : memref<4x?x8xf32, strided<[32, 8, 1]>> to memref<4x1x8xf32, strided<[32, 16, 1]>>
+ %7 = memref.cast %arg4 : memref<4x?x8xf32, strided<[32, 8, 1]>> to memref<4x1x8xf32, strided<[32, 16, 1]>>
return
}
More information about the Mlir-commits
mailing list