[Mlir-commits] [mlir] Validate type consistency in reinterpret cast sizes (PR #140032)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 15 02:49:33 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

Author: None (yaniv217)

<details>
<summary>Changes</summary>

Ensure that, when performing a reinterpret cast, each expected size and the corresponding result size are of the same kind, i.e. both static or both dynamic. Emit an error if a dimension has a static size in one and a dynamic size in the other.

---
Full diff: https://github.com/llvm/llvm-project/pull/140032.diff


5 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+11-2) 
- (modified) mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (+6-3) 
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+7-7) 
- (modified) mlir/test/Dialect/MemRef/invalid.mlir (+18) 
- (modified) mlir/test/Dialect/MemRef/ops.mlir (+6-6) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a0237c18cf2fe..5a348b823d02b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1052,7 +1052,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
         }
       } // else dim.getIndex is a block argument to reshape->getBlock and
         // dominates reshape
-    } // Check condition 2
+    }   // Check condition 2
     else if (dim->getBlock() != reshape->getBlock() &&
              !dim.getIndex().getParentRegion()->isProperAncestor(
                  reshape->getParentRegion())) {
@@ -1835,6 +1835,15 @@ LogicalResult ReinterpretCastOp::verify() {
   // Match sizes in result memref type and in static_sizes attribute.
   for (auto [idx, resultSize, expectedSize] :
        llvm::enumerate(resultType.getShape(), getStaticSizes())) {
+    // Check that dynamic sizes are not mixed with static sizes
+    if (ShapedType::isDynamic(resultSize) &&
+        !ShapedType::isDynamic(expectedSize))
+      return emitError(
+          "expectedSize is static but received a dynamic resultSize ");
+    if (!ShapedType::isDynamic(resultSize) &&
+        ShapedType::isDynamic(expectedSize))
+      return emitError(
+          "expectedSize is dynamic but received a static resultSize ");
     if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
       return emitError("expected result type with size = ")
              << (ShapedType::isDynamic(expectedSize)
@@ -2008,7 +2017,7 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
       // Second, check the sizes.
       if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
                        op.getConstifiedMixedSizes()))
-          return false;
+        return false;
 
       // Finally, check the offset.
       assert(op.getMixedOffsets().size() == 1 &&
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 8906de9db3724..18b151c469da6 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -339,7 +339,8 @@ func.func @reinterpret_cast(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgr
 //       CHECK:  %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF1]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
 //       CHECK:  %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
 //       CHECK:  return %[[RET1]]
-  %ret = memref.reinterpret_cast %arg to offset: [%arg1], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+  %c10 = arith.constant 10 : index
+  %ret = memref.reinterpret_cast %arg to offset: [%arg1], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
   return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
 }
 
@@ -349,7 +350,8 @@ func.func @reinterpret_cast_0(%arg: memref<?xf32, #spirv.storage_class<CrossWork
 //   CHECK-DAG:  %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
 //   CHECK-DAG:  %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
 //       CHECK:  return %[[RET]]
-  %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+  %c10 = arith.constant 10 : index
+  %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
   return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
 }
 
@@ -361,7 +363,8 @@ func.func @reinterpret_cast_5(%arg: memref<?xf32, #spirv.storage_class<CrossWork
 //       CHECK:  %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
 //       CHECK:  %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
 //       CHECK:  return %[[RET1]]
-  %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+  %c10 = arith.constant 10 : index
+  %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
   return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
 }
 
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e7cee7cd85426..a53a5d10eceb5 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -923,13 +923,13 @@ func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref
 // same constant value, the match is valid.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_constants
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32,
+//       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x2xf32,
 //       CHECK: return %[[CAST]]
-func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x2xf32, strided<[?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
   %c8 = arith.constant 8: index
-  %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
-  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x2xf32, strided<[?, ?], offset: ?>>
+  return %m2 : memref<?x2xf32, strided<[?, ?], offset: ?>>
 }
 // -----
 
@@ -954,10 +954,10 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
 //       CHECK: return %[[RES]]
-func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
+func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<4x2x2xf32, strided<[?, ?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
-  %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
-  return %m2 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+  %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<4x2x2xf32, strided<[?, ?, ?], offset: ?>>
+  return %m2 : memref<4x2x2xf32, strided<[?, ?, ?], offset: ?>>
 }
 // -----
 
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 34fc4775924e7..c98d4913dc5d2 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -245,6 +245,24 @@ func.func @memref_reinterpret_cast_no_map_but_strides(%in: memref<?x?xf32>) {
 
 // -----
 
+func.func @memref_reinterpret_cast_static_dynamic_size_mismatch(%in: memref<1x?x2x1xf32>) {
+  // expected-error at +1 {{expectedSize is static but received a dynamic resultSize}}
+  %out = memref.reinterpret_cast %in to
+         offset: [0], sizes: [1, 4672, 1, 1], strides: [4672, 8, 8, 1]
+       : memref<1x?x2x1xf32> to memref<1x4672x?x1xf32> 
+}
+
+// -----
+
+func.func @memref_reinterpret_cast_dynamic_static_size_mismatch(%in: memref<1x?x2x1xf32>, %size: index) {
+  // expected-error at +1 {{expectedSize is dynamic but received a static resultSize}}
+  %out = memref.reinterpret_cast %in to
+         offset: [0], sizes: [1, %size, 1, 1], strides: [4672, 8, 8, 1]
+       : memref<1x?x2x1xf32> to memref<1x4672x2x1xf32>
+  return
+}
+
+// -----
 func.func @memref_reshape_element_type_mismatch(
        %buf: memref<*xf32>, %shape: memref<1xi32>) {
   // expected-error @+1 {{element types of source and destination memref types should be the same}}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 7038a6ff744e4..03e344e0e9cf2 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -131,20 +131,20 @@ func.func @memref_reinterpret_cast(%in: memref<?xf32>)
 
 // CHECK-LABEL: func @memref_reinterpret_cast_static_to_dynamic_sizes
 func.func @memref_reinterpret_cast_static_to_dynamic_sizes(%in: memref<?xf32>)
-    -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+    -> memref<10x10xf32, strided<[?, 1], offset: ?>> {
   %out = memref.reinterpret_cast %in to
            offset: [1], sizes: [10, 10], strides: [1, 1]
-           : memref<?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
-  return %out : memref<10x?xf32, strided<[?, 1], offset: ?>>
+           : memref<?xf32> to memref<10x10xf32, strided<[?, 1], offset: ?>>
+  return %out : memref<10x10xf32, strided<[?, 1], offset: ?>>
 }
 
 // CHECK-LABEL: func @memref_reinterpret_cast_dynamic_offset
 func.func @memref_reinterpret_cast_dynamic_offset(%in: memref<?xf32>, %offset: index)
-    -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+    -> memref<10x10xf32, strided<[?, 1], offset: ?>> {
   %out = memref.reinterpret_cast %in to
            offset: [%offset], sizes: [10, 10], strides: [1, 1]
-           : memref<?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
-  return %out : memref<10x?xf32, strided<[?, 1], offset: ?>>
+           : memref<?xf32> to memref<10x10xf32, strided<[?, 1], offset: ?>>
+  return %out : memref<10x10xf32, strided<[?, 1], offset: ?>>
 }
 
 // CHECK-LABEL: func @memref_reshape(

``````````

</details>


https://github.com/llvm/llvm-project/pull/140032


More information about the Mlir-commits mailing list