[Mlir-commits] [mlir] Validate type consistency in reinterpret cast sizes (PR #140032)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 15 02:49:33 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref
Author: None (yaniv217)
<details>
<summary>Changes</summary>
Ensure that when performing a reinterpret cast, each expected size and the corresponding result size are of the same kind (both static or both dynamic). Emit an error if a dimension has a static size on one side and a dynamic size on the other.
---
Full diff: https://github.com/llvm/llvm-project/pull/140032.diff
5 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+11-2)
- (modified) mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (+6-3)
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+7-7)
- (modified) mlir/test/Dialect/MemRef/invalid.mlir (+18)
- (modified) mlir/test/Dialect/MemRef/ops.mlir (+6-6)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a0237c18cf2fe..5a348b823d02b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1052,7 +1052,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
}
} // else dim.getIndex is a block argument to reshape->getBlock and
// dominates reshape
- } // Check condition 2
+ } // Check condition 2
else if (dim->getBlock() != reshape->getBlock() &&
!dim.getIndex().getParentRegion()->isProperAncestor(
reshape->getParentRegion())) {
@@ -1835,6 +1835,15 @@ LogicalResult ReinterpretCastOp::verify() {
// Match sizes in result memref type and in static_sizes attribute.
for (auto [idx, resultSize, expectedSize] :
llvm::enumerate(resultType.getShape(), getStaticSizes())) {
+ // Check that dynamic sizes are not mixed with static sizes
+ if (ShapedType::isDynamic(resultSize) &&
+ !ShapedType::isDynamic(expectedSize))
+ return emitError(
+ "expectedSize is static but received a dynamic resultSize ");
+ if (!ShapedType::isDynamic(resultSize) &&
+ ShapedType::isDynamic(expectedSize))
+ return emitError(
+ "expectedSize is dynamic but received a static resultSize ");
if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
return emitError("expected result type with size = ")
<< (ShapedType::isDynamic(expectedSize)
@@ -2008,7 +2017,7 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
// Second, check the sizes.
if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
op.getConstifiedMixedSizes()))
- return false;
+ return false;
// Finally, check the offset.
assert(op.getMixedOffsets().size() == 1 &&
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 8906de9db3724..18b151c469da6 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -339,7 +339,8 @@ func.func @reinterpret_cast(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgr
// CHECK: %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF1]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
// CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
// CHECK: return %[[RET1]]
- %ret = memref.reinterpret_cast %arg to offset: [%arg1], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+ %c10 = arith.constant 10 : index
+ %ret = memref.reinterpret_cast %arg to offset: [%arg1], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
}
@@ -349,7 +350,8 @@ func.func @reinterpret_cast_0(%arg: memref<?xf32, #spirv.storage_class<CrossWork
// CHECK-DAG: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
// CHECK-DAG: %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
// CHECK: return %[[RET]]
- %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+ %c10 = arith.constant 10 : index
+ %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
}
@@ -361,7 +363,8 @@ func.func @reinterpret_cast_5(%arg: memref<?xf32, #spirv.storage_class<CrossWork
// CHECK: %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
// CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
// CHECK: return %[[RET1]]
- %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+ %c10 = arith.constant 10 : index
+ %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e7cee7cd85426..a53a5d10eceb5 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -923,13 +923,13 @@ func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref
// same constant value, the match is valid.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_constants
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32,
+// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x2xf32,
// CHECK: return %[[CAST]]
-func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x2xf32, strided<[?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
%c8 = arith.constant 8: index
- %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
- return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
+ %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x2xf32, strided<[?, ?], offset: ?>>
+ return %m2 : memref<?x2xf32, strided<[?, ?], offset: ?>>
}
// -----
@@ -954,10 +954,10 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
// CHECK: return %[[RES]]
-func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
+func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<4x2x2xf32, strided<[?, ?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
- %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
- return %m2 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+ %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<4x2x2xf32, strided<[?, ?, ?], offset: ?>>
+ return %m2 : memref<4x2x2xf32, strided<[?, ?, ?], offset: ?>>
}
// -----
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 34fc4775924e7..c98d4913dc5d2 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -245,6 +245,24 @@ func.func @memref_reinterpret_cast_no_map_but_strides(%in: memref<?x?xf32>) {
// -----
+func.func @memref_reinterpret_cast_static_dynamic_size_mismatch(%in: memref<1x?x2x1xf32>) {
+ // expected-error at +1 {{expectedSize is static but received a dynamic resultSize}}
+ %out = memref.reinterpret_cast %in to
+ offset: [0], sizes: [1, 4672, 1, 1], strides: [4672, 8, 8, 1]
+ : memref<1x?x2x1xf32> to memref<1x4672x?x1xf32>
+}
+
+// -----
+
+func.func @memref_reinterpret_cast_dynamic_static_size_mismatch(%in: memref<1x?x2x1xf32>, %size: index) {
+ // expected-error at +1 {{expectedSize is dynamic but received a static resultSize}}
+ %out = memref.reinterpret_cast %in to
+ offset: [0], sizes: [1, %size, 1, 1], strides: [4672, 8, 8, 1]
+ : memref<1x?x2x1xf32> to memref<1x4672x2x1xf32>
+ return
+}
+
+// -----
func.func @memref_reshape_element_type_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) {
// expected-error @+1 {{element types of source and destination memref types should be the same}}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 7038a6ff744e4..03e344e0e9cf2 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -131,20 +131,20 @@ func.func @memref_reinterpret_cast(%in: memref<?xf32>)
// CHECK-LABEL: func @memref_reinterpret_cast_static_to_dynamic_sizes
func.func @memref_reinterpret_cast_static_to_dynamic_sizes(%in: memref<?xf32>)
- -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+ -> memref<10x10xf32, strided<[?, 1], offset: ?>> {
%out = memref.reinterpret_cast %in to
offset: [1], sizes: [10, 10], strides: [1, 1]
- : memref<?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
- return %out : memref<10x?xf32, strided<[?, 1], offset: ?>>
+ : memref<?xf32> to memref<10x10xf32, strided<[?, 1], offset: ?>>
+ return %out : memref<10x10xf32, strided<[?, 1], offset: ?>>
}
// CHECK-LABEL: func @memref_reinterpret_cast_dynamic_offset
func.func @memref_reinterpret_cast_dynamic_offset(%in: memref<?xf32>, %offset: index)
- -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+ -> memref<10x10xf32, strided<[?, 1], offset: ?>> {
%out = memref.reinterpret_cast %in to
offset: [%offset], sizes: [10, 10], strides: [1, 1]
- : memref<?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
- return %out : memref<10x?xf32, strided<[?, 1], offset: ?>>
+ : memref<?xf32> to memref<10x10xf32, strided<[?, 1], offset: ?>>
+ return %out : memref<10x10xf32, strided<[?, 1], offset: ?>>
}
// CHECK-LABEL: func @memref_reshape(
``````````
</details>
https://github.com/llvm/llvm-project/pull/140032
More information about the Mlir-commits mailing list