[Mlir-commits] [mlir] Validate type consistency in reinterpret cast sizes (PR #140032)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 27 23:48:53 PDT 2025
https://github.com/yaniv217 updated https://github.com/llvm/llvm-project/pull/140032
>From 01fa7a292177dbbd81de9ca2fdc01fa0cf0839d9 Mon Sep 17 00:00:00 2001
From: Yaniv Kaniel <yaniv.kaniel at gmail.com>
Date: Thu, 15 May 2025 12:33:48 +0300
Subject: [PATCH 1/2] Validate type consistency in reinterpret cast sizes
Ensure that, when performing a reinterpret cast, each entry of the
`sizes` list and the corresponding dimension of the result type are
either both static or both dynamic. Emit an error if one of them is
static and the other is dynamic.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 13 +++++++++++--
.../MemRefToSPIRV/memref-to-spirv.mlir | 9 ++++++---
mlir/test/Dialect/MemRef/canonicalize.mlir | 14 +++++++-------
mlir/test/Dialect/MemRef/invalid.mlir | 18 ++++++++++++++++++
mlir/test/Dialect/MemRef/ops.mlir | 12 ++++++------
5 files changed, 48 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a0237c18cf2fe..5a348b823d02b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1052,7 +1052,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
}
} // else dim.getIndex is a block argument to reshape->getBlock and
// dominates reshape
- } // Check condition 2
+ } // Check condition 2
else if (dim->getBlock() != reshape->getBlock() &&
!dim.getIndex().getParentRegion()->isProperAncestor(
reshape->getParentRegion())) {
@@ -1835,6 +1835,15 @@ LogicalResult ReinterpretCastOp::verify() {
// Match sizes in result memref type and in static_sizes attribute.
for (auto [idx, resultSize, expectedSize] :
llvm::enumerate(resultType.getShape(), getStaticSizes())) {
+ // Check that dynamic sizes are not mixed with static sizes
+ if (ShapedType::isDynamic(resultSize) &&
+ !ShapedType::isDynamic(expectedSize))
+ return emitError(
+ "expectedSize is static but received a dynamic resultSize ");
+ if (!ShapedType::isDynamic(resultSize) &&
+ ShapedType::isDynamic(expectedSize))
+ return emitError(
+ "expectedSize is dynamic but received a static resultSize ");
if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
return emitError("expected result type with size = ")
<< (ShapedType::isDynamic(expectedSize)
@@ -2008,7 +2017,7 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
// Second, check the sizes.
if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
op.getConstifiedMixedSizes()))
- return false;
+ return false;
// Finally, check the offset.
assert(op.getMixedOffsets().size() == 1 &&
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 8906de9db3724..18b151c469da6 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -339,7 +339,8 @@ func.func @reinterpret_cast(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgr
// CHECK: %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF1]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
// CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
// CHECK: return %[[RET1]]
- %ret = memref.reinterpret_cast %arg to offset: [%arg1], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+ %c10 = arith.constant 10 : index
+ %ret = memref.reinterpret_cast %arg to offset: [%arg1], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
}
@@ -349,7 +350,8 @@ func.func @reinterpret_cast_0(%arg: memref<?xf32, #spirv.storage_class<CrossWork
// CHECK-DAG: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
// CHECK-DAG: %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
// CHECK: return %[[RET]]
- %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+ %c10 = arith.constant 10 : index
+ %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
}
@@ -361,7 +363,8 @@ func.func @reinterpret_cast_5(%arg: memref<?xf32, #spirv.storage_class<CrossWork
// CHECK: %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
// CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
// CHECK: return %[[RET1]]
- %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+ %c10 = arith.constant 10 : index
+ %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e7cee7cd85426..a53a5d10eceb5 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -923,13 +923,13 @@ func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref
// same constant value, the match is valid.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_constants
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32,
+// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x2xf32,
// CHECK: return %[[CAST]]
-func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x2xf32, strided<[?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
%c8 = arith.constant 8: index
- %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
- return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
+ %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x2xf32, strided<[?, ?], offset: ?>>
+ return %m2 : memref<?x2xf32, strided<[?, ?], offset: ?>>
}
// -----
@@ -954,10 +954,10 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
// CHECK: return %[[RES]]
-func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
+func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<4x2x2xf32, strided<[?, ?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
- %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
- return %m2 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+ %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<4x2x2xf32, strided<[?, ?, ?], offset: ?>>
+ return %m2 : memref<4x2x2xf32, strided<[?, ?, ?], offset: ?>>
}
// -----
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 34fc4775924e7..c98d4913dc5d2 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -245,6 +245,24 @@ func.func @memref_reinterpret_cast_no_map_but_strides(%in: memref<?x?xf32>) {
// -----
+func.func @memref_reinterpret_cast_static_dynamic_size_mismatch(%in: memref<1x?x2x1xf32>) {
+ // expected-error at +1 {{expectedSize is static but received a dynamic resultSize}}
+ %out = memref.reinterpret_cast %in to
+ offset: [0], sizes: [1, 4672, 1, 1], strides: [4672, 8, 8, 1]
+ : memref<1x?x2x1xf32> to memref<1x4672x?x1xf32>
+}
+
+// -----
+
+func.func @memref_reinterpret_cast_dynamic_static_size_mismatch(%in: memref<1x?x2x1xf32>, %size: index) {
+ // expected-error at +1 {{expectedSize is dynamic but received a static resultSize}}
+ %out = memref.reinterpret_cast %in to
+ offset: [0], sizes: [1, %size, 1, 1], strides: [4672, 8, 8, 1]
+ : memref<1x?x2x1xf32> to memref<1x4672x2x1xf32>
+ return
+}
+
+// -----
func.func @memref_reshape_element_type_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) {
// expected-error @+1 {{element types of source and destination memref types should be the same}}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 7038a6ff744e4..03e344e0e9cf2 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -131,20 +131,20 @@ func.func @memref_reinterpret_cast(%in: memref<?xf32>)
// CHECK-LABEL: func @memref_reinterpret_cast_static_to_dynamic_sizes
func.func @memref_reinterpret_cast_static_to_dynamic_sizes(%in: memref<?xf32>)
- -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+ -> memref<10x10xf32, strided<[?, 1], offset: ?>> {
%out = memref.reinterpret_cast %in to
offset: [1], sizes: [10, 10], strides: [1, 1]
- : memref<?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
- return %out : memref<10x?xf32, strided<[?, 1], offset: ?>>
+ : memref<?xf32> to memref<10x10xf32, strided<[?, 1], offset: ?>>
+ return %out : memref<10x10xf32, strided<[?, 1], offset: ?>>
}
// CHECK-LABEL: func @memref_reinterpret_cast_dynamic_offset
func.func @memref_reinterpret_cast_dynamic_offset(%in: memref<?xf32>, %offset: index)
- -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+ -> memref<10x10xf32, strided<[?, 1], offset: ?>> {
%out = memref.reinterpret_cast %in to
offset: [%offset], sizes: [10, 10], strides: [1, 1]
- : memref<?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
- return %out : memref<10x?xf32, strided<[?, 1], offset: ?>>
+ : memref<?xf32> to memref<10x10xf32, strided<[?, 1], offset: ?>>
+ return %out : memref<10x10xf32, strided<[?, 1], offset: ?>>
}
// CHECK-LABEL: func @memref_reshape(
>From c9814d98ee08699ce244f9bf473519dce9a3d7d3 Mon Sep 17 00:00:00 2001
From: Yaniv Kaniel <yaniv.kaniel at gmail.com>
Date: Wed, 21 May 2025 14:57:08 +0300
Subject: [PATCH 2/2] Validate type consistency in reinterpret cast offsets
Ensure that, when performing a reinterpret cast, the `offset` operand
and the offset in the result type's strided layout are either both
static or both dynamic. Emit an error if one of them is static and the
other is dynamic. Delete a previous test that is a specific instance of
this case.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 13 ++++++--
.../MemRefToSPIRV/memref-to-spirv.mlir | 16 ++++-----
mlir/test/Dialect/MemRef/canonicalize.mlir | 12 +++----
mlir/test/Dialect/MemRef/invalid.mlir | 33 ++++++++++++-------
mlir/test/Dialect/MemRef/ops.mlir | 6 ++--
5 files changed, 49 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 5a348b823d02b..82fc4eac5b40b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1839,11 +1839,11 @@ LogicalResult ReinterpretCastOp::verify() {
if (ShapedType::isDynamic(resultSize) &&
!ShapedType::isDynamic(expectedSize))
return emitError(
- "expectedSize is static but received a dynamic resultSize ");
+ "expected size is static, but result type dimension is dynamic ");
if (!ShapedType::isDynamic(resultSize) &&
ShapedType::isDynamic(expectedSize))
return emitError(
- "expectedSize is dynamic but received a static resultSize ");
+ "expected size is dynamic, but result type dimension is static ");
if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
return emitError("expected result type with size = ")
<< (ShapedType::isDynamic(expectedSize)
@@ -1863,6 +1863,15 @@ LogicalResult ReinterpretCastOp::verify() {
// Match offset in result memref type and in static_offsets attribute.
int64_t expectedOffset = getStaticOffsets().front();
+ // Check that dynamic offset is not mixed with static offset
+ if (ShapedType::isDynamic(resultOffset) &&
+ !ShapedType::isDynamic(expectedOffset))
+ return emitError(
+ "expected offset is static, but result type offset is dynamic");
+ if (!ShapedType::isDynamic(resultOffset) &&
+ ShapedType::isDynamic(expectedOffset))
+ return emitError(
+ "expected offset is dynamic, but result type offset is static");
if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
return emitError("expected result type with offset = ")
<< (ShapedType::isDynamic(expectedOffset)
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 18b151c469da6..fbc1b8ca42377 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -346,26 +346,26 @@ func.func @reinterpret_cast(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgr
// CHECK-LABEL: func.func @reinterpret_cast_0
// CHECK-SAME: (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>)
-func.func @reinterpret_cast_0(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>> {
+func.func @reinterpret_cast_0(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?xf32, strided<[1]>, #spirv.storage_class<CrossWorkgroup>> {
// CHECK-DAG: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
-// CHECK-DAG: %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+// CHECK-DAG: %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1]>, #spirv.storage_class<CrossWorkgroup>>
// CHECK: return %[[RET]]
%c10 = arith.constant 10 : index
- %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
- return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+ %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: 0>, #spirv.storage_class<CrossWorkgroup>>
+ return %ret : memref<?xf32, strided<[1], offset: 0>, #spirv.storage_class<CrossWorkgroup>>
}
// CHECK-LABEL: func.func @reinterpret_cast_5
// CHECK-SAME: (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>)
-func.func @reinterpret_cast_5(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>> {
+func.func @reinterpret_cast_5(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?xf32, strided<[1], offset: 5>, #spirv.storage_class<CrossWorkgroup>> {
// CHECK: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
// CHECK: %[[OFF:.*]] = spirv.Constant 5 : i32
// CHECK: %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
-// CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+// CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: 5>, #spirv.storage_class<CrossWorkgroup>>
// CHECK: return %[[RET1]]
%c10 = arith.constant 10 : index
- %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
- return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+ %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: 5>, #spirv.storage_class<CrossWorkgroup>>
+ return %ret : memref<?xf32, strided<[1], offset: 5>, #spirv.storage_class<CrossWorkgroup>>
}
} // end module
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a53a5d10eceb5..7d293fcec0083 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -925,11 +925,11 @@ func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x2xf32,
// CHECK: return %[[CAST]]
-func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x2xf32, strided<[?, ?], offset: ?>> {
+func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x2xf32, strided<[?, ?], offset: 0>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
%c8 = arith.constant 8: index
- %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x2xf32, strided<[?, ?], offset: ?>>
- return %m2 : memref<?x2xf32, strided<[?, ?], offset: ?>>
+ %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x2xf32, strided<[?, ?], offset: 0>>
+ return %m2 : memref<?x2xf32, strided<[?, ?], offset: 0>>
}
// -----
@@ -970,10 +970,10 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
// CHECK: return %[[RES]]
-func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: 1>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
- %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
- return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
+ %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: 1>>
+ return %m2 : memref<?x?xf32, strided<[?, ?], offset: 1>>
}
// -----
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index c98d4913dc5d2..68d88e9214705 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -214,16 +214,6 @@ func.func @memref_reinterpret_cast_no_map_but_offset(%in: memref<?xf32>) {
: memref<?xf32> to memref<10xf32>
return
}
-
-// -----
-
-func.func @memref_reinterpret_cast_offset_mismatch_dynamic(%in: memref<?xf32>, %offset : index) {
- // expected-error @+1 {{expected result type with offset = dynamic instead of 0}}
- %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
- : memref<?xf32> to memref<10xf32>
- return
-}
-
// -----
func.func @memref_reinterpret_cast_no_map_but_stride(%in: memref<?xf32>) {
@@ -246,7 +236,7 @@ func.func @memref_reinterpret_cast_no_map_but_strides(%in: memref<?x?xf32>) {
// -----
func.func @memref_reinterpret_cast_static_dynamic_size_mismatch(%in: memref<1x?x2x1xf32>) {
- // expected-error at +1 {{expectedSize is static but received a dynamic resultSize}}
+ // expected-error at +1 {{expected size is static, but result type dimension is dynamic }}
%out = memref.reinterpret_cast %in to
offset: [0], sizes: [1, 4672, 1, 1], strides: [4672, 8, 8, 1]
: memref<1x?x2x1xf32> to memref<1x4672x?x1xf32>
@@ -255,13 +245,32 @@ func.func @memref_reinterpret_cast_static_dynamic_size_mismatch(%in: memref<1x?x
// -----
func.func @memref_reinterpret_cast_dynamic_static_size_mismatch(%in: memref<1x?x2x1xf32>, %size: index) {
- // expected-error at +1 {{expectedSize is dynamic but received a static resultSize}}
+ // expected-error at +1 {{expected size is dynamic, but result type dimension is static }}
%out = memref.reinterpret_cast %in to
offset: [0], sizes: [1, %size, 1, 1], strides: [4672, 8, 8, 1]
: memref<1x?x2x1xf32> to memref<1x4672x2x1xf32>
return
}
+// -----
+
+func.func @memref_reinterpret_cast_static_dynamic_offset_mismatch(%in: memref<?xf32>) {
+ // expected-error at +1 {{expected offset is static, but result type offset is dynamic}}
+ %out = memref.reinterpret_cast %in to
+ offset: [0], sizes: [10], strides: [1]
+ : memref<?xf32> to memref<10xf32, strided<[1], offset: ?>>
+}
+
+// -----
+
+func.func @memref_reinterpret_cast_dynamic_static_offset_mismatch(%in: memref<?xf32>, %offset: index) {
+ // expected-error at +1 {{expected offset is dynamic, but result type offset is static}}
+ %out = memref.reinterpret_cast %in to
+ offset: [%offset], sizes: [10], strides: [1]
+ : memref<?xf32> to memref<10xf32, strided<[1], offset: 0>>
+ return
+}
+
// -----
func.func @memref_reshape_element_type_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) {
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 03e344e0e9cf2..0685334cd20ea 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -131,11 +131,11 @@ func.func @memref_reinterpret_cast(%in: memref<?xf32>)
// CHECK-LABEL: func @memref_reinterpret_cast_static_to_dynamic_sizes
func.func @memref_reinterpret_cast_static_to_dynamic_sizes(%in: memref<?xf32>)
- -> memref<10x10xf32, strided<[?, 1], offset: ?>> {
+ -> memref<10x10xf32, strided<[?, 1], offset: 1>> {
%out = memref.reinterpret_cast %in to
offset: [1], sizes: [10, 10], strides: [1, 1]
- : memref<?xf32> to memref<10x10xf32, strided<[?, 1], offset: ?>>
- return %out : memref<10x10xf32, strided<[?, 1], offset: ?>>
+ : memref<?xf32> to memref<10x10xf32, strided<[?, 1], offset: 1>>
+ return %out : memref<10x10xf32, strided<[?, 1], offset: 1>>
}
// CHECK-LABEL: func @memref_reinterpret_cast_dynamic_offset
More information about the Mlir-commits
mailing list