[Mlir-commits] [mlir] Validate type consistency in reinterpret cast sizes (PR #140032)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 27 23:48:53 PDT 2025


https://github.com/yaniv217 updated https://github.com/llvm/llvm-project/pull/140032

>From 01fa7a292177dbbd81de9ca2fdc01fa0cf0839d9 Mon Sep 17 00:00:00 2001
From: Yaniv Kaniel <yaniv.kaniel at gmail.com>
Date: Thu, 15 May 2025 12:33:48 +0300
Subject: [PATCH 1/2] Validate type consistency in reinterpret cast sizes

Ensure that, when performing a reinterpret cast, each size given to the
op and the corresponding dimension of the result memref type are either
both static or both dynamic. Emit an error if a dimension is static in
one and dynamic in the other.
---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp       | 13 +++++++++++--
 .../MemRefToSPIRV/memref-to-spirv.mlir         |  9 ++++++---
 mlir/test/Dialect/MemRef/canonicalize.mlir     | 14 +++++++-------
 mlir/test/Dialect/MemRef/invalid.mlir          | 18 ++++++++++++++++++
 mlir/test/Dialect/MemRef/ops.mlir              | 12 ++++++------
 5 files changed, 48 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a0237c18cf2fe..5a348b823d02b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1052,7 +1052,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
         }
       } // else dim.getIndex is a block argument to reshape->getBlock and
         // dominates reshape
-    } // Check condition 2
+    }   // Check condition 2
     else if (dim->getBlock() != reshape->getBlock() &&
              !dim.getIndex().getParentRegion()->isProperAncestor(
                  reshape->getParentRegion())) {
@@ -1835,6 +1835,15 @@ LogicalResult ReinterpretCastOp::verify() {
   // Match sizes in result memref type and in static_sizes attribute.
   for (auto [idx, resultSize, expectedSize] :
        llvm::enumerate(resultType.getShape(), getStaticSizes())) {
+    // Check that dynamic sizes are not mixed with static sizes
+    if (ShapedType::isDynamic(resultSize) &&
+        !ShapedType::isDynamic(expectedSize))
+      return emitError(
+          "expectedSize is static but received a dynamic resultSize ");
+    if (!ShapedType::isDynamic(resultSize) &&
+        ShapedType::isDynamic(expectedSize))
+      return emitError(
+          "expectedSize is dynamic but received a static resultSize ");
     if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
       return emitError("expected result type with size = ")
              << (ShapedType::isDynamic(expectedSize)
@@ -2008,7 +2017,7 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
       // Second, check the sizes.
       if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
                        op.getConstifiedMixedSizes()))
-          return false;
+        return false;
 
       // Finally, check the offset.
       assert(op.getMixedOffsets().size() == 1 &&
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 8906de9db3724..18b151c469da6 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -339,7 +339,8 @@ func.func @reinterpret_cast(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgr
 //       CHECK:  %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF1]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
 //       CHECK:  %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
 //       CHECK:  return %[[RET1]]
-  %ret = memref.reinterpret_cast %arg to offset: [%arg1], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+  %c10 = arith.constant 10 : index
+  %ret = memref.reinterpret_cast %arg to offset: [%arg1], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
   return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
 }
 
@@ -349,7 +350,8 @@ func.func @reinterpret_cast_0(%arg: memref<?xf32, #spirv.storage_class<CrossWork
 //   CHECK-DAG:  %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
 //   CHECK-DAG:  %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
 //       CHECK:  return %[[RET]]
-  %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+  %c10 = arith.constant 10 : index
+  %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
   return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
 }
 
@@ -361,7 +363,8 @@ func.func @reinterpret_cast_5(%arg: memref<?xf32, #spirv.storage_class<CrossWork
 //       CHECK:  %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
 //       CHECK:  %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
 //       CHECK:  return %[[RET1]]
-  %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+  %c10 = arith.constant 10 : index
+  %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
   return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
 }
 
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e7cee7cd85426..a53a5d10eceb5 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -923,13 +923,13 @@ func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref
 // same constant value, the match is valid.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_constants
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32,
+//       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x2xf32,
 //       CHECK: return %[[CAST]]
-func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x2xf32, strided<[?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
   %c8 = arith.constant 8: index
-  %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
-  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x2xf32, strided<[?, ?], offset: ?>>
+  return %m2 : memref<?x2xf32, strided<[?, ?], offset: ?>>
 }
 // -----
 
@@ -954,10 +954,10 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
 //       CHECK: return %[[RES]]
-func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
+func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<4x2x2xf32, strided<[?, ?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
-  %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
-  return %m2 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+  %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<4x2x2xf32, strided<[?, ?, ?], offset: ?>>
+  return %m2 : memref<4x2x2xf32, strided<[?, ?, ?], offset: ?>>
 }
 // -----
 
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 34fc4775924e7..c98d4913dc5d2 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -245,6 +245,24 @@ func.func @memref_reinterpret_cast_no_map_but_strides(%in: memref<?x?xf32>) {
 
 // -----
 
+func.func @memref_reinterpret_cast_static_dynamic_size_mismatch(%in: memref<1x?x2x1xf32>) {
+  // expected-error at +1 {{expectedSize is static but received a dynamic resultSize}}
+  %out = memref.reinterpret_cast %in to
+         offset: [0], sizes: [1, 4672, 1, 1], strides: [4672, 8, 8, 1]
+       : memref<1x?x2x1xf32> to memref<1x4672x?x1xf32> 
+}
+
+// -----
+
+func.func @memref_reinterpret_cast_dynamic_static_size_mismatch(%in: memref<1x?x2x1xf32>, %size: index) {
+  // expected-error at +1 {{expectedSize is dynamic but received a static resultSize}}
+  %out = memref.reinterpret_cast %in to
+         offset: [0], sizes: [1, %size, 1, 1], strides: [4672, 8, 8, 1]
+       : memref<1x?x2x1xf32> to memref<1x4672x2x1xf32>
+  return
+}
+
+// -----
 func.func @memref_reshape_element_type_mismatch(
        %buf: memref<*xf32>, %shape: memref<1xi32>) {
   // expected-error @+1 {{element types of source and destination memref types should be the same}}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 7038a6ff744e4..03e344e0e9cf2 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -131,20 +131,20 @@ func.func @memref_reinterpret_cast(%in: memref<?xf32>)
 
 // CHECK-LABEL: func @memref_reinterpret_cast_static_to_dynamic_sizes
 func.func @memref_reinterpret_cast_static_to_dynamic_sizes(%in: memref<?xf32>)
-    -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+    -> memref<10x10xf32, strided<[?, 1], offset: ?>> {
   %out = memref.reinterpret_cast %in to
            offset: [1], sizes: [10, 10], strides: [1, 1]
-           : memref<?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
-  return %out : memref<10x?xf32, strided<[?, 1], offset: ?>>
+           : memref<?xf32> to memref<10x10xf32, strided<[?, 1], offset: ?>>
+  return %out : memref<10x10xf32, strided<[?, 1], offset: ?>>
 }
 
 // CHECK-LABEL: func @memref_reinterpret_cast_dynamic_offset
 func.func @memref_reinterpret_cast_dynamic_offset(%in: memref<?xf32>, %offset: index)
-    -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+    -> memref<10x10xf32, strided<[?, 1], offset: ?>> {
   %out = memref.reinterpret_cast %in to
            offset: [%offset], sizes: [10, 10], strides: [1, 1]
-           : memref<?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
-  return %out : memref<10x?xf32, strided<[?, 1], offset: ?>>
+           : memref<?xf32> to memref<10x10xf32, strided<[?, 1], offset: ?>>
+  return %out : memref<10x10xf32, strided<[?, 1], offset: ?>>
 }
 
 // CHECK-LABEL: func @memref_reshape(

>From c9814d98ee08699ce244f9bf473519dce9a3d7d3 Mon Sep 17 00:00:00 2001
From: Yaniv Kaniel <yaniv.kaniel at gmail.com>
Date: Wed, 21 May 2025 14:57:08 +0300
Subject: [PATCH 2/2] Validate type consistency in reinterpret cast offsets

Ensure that, when performing a reinterpret cast, the offset given to the
op and the offset of the result memref type are either both static or
both dynamic. Emit an error if one of them is static and the other is
dynamic. Delete a previous test that is a specific instance of this
case.
---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 13 ++++++--
 .../MemRefToSPIRV/memref-to-spirv.mlir        | 16 ++++-----
 mlir/test/Dialect/MemRef/canonicalize.mlir    | 12 +++----
 mlir/test/Dialect/MemRef/invalid.mlir         | 33 ++++++++++++-------
 mlir/test/Dialect/MemRef/ops.mlir             |  6 ++--
 5 files changed, 49 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 5a348b823d02b..82fc4eac5b40b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1839,11 +1839,11 @@ LogicalResult ReinterpretCastOp::verify() {
     if (ShapedType::isDynamic(resultSize) &&
         !ShapedType::isDynamic(expectedSize))
       return emitError(
-          "expectedSize is static but received a dynamic resultSize ");
+          "expected size is static, but result type dimension is dynamic ");
     if (!ShapedType::isDynamic(resultSize) &&
         ShapedType::isDynamic(expectedSize))
       return emitError(
-          "expectedSize is dynamic but received a static resultSize ");
+          "expected size is dynamic, but result type dimension is static ");
     if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
       return emitError("expected result type with size = ")
              << (ShapedType::isDynamic(expectedSize)
@@ -1863,6 +1863,15 @@ LogicalResult ReinterpretCastOp::verify() {
 
   // Match offset in result memref type and in static_offsets attribute.
   int64_t expectedOffset = getStaticOffsets().front();
+  // Check that dynamic offset is not mixed with static offset
+  if (ShapedType::isDynamic(resultOffset) &&
+      !ShapedType::isDynamic(expectedOffset))
+    return emitError(
+        "expected offset is static, but result type offset is dynamic");
+  if (!ShapedType::isDynamic(resultOffset) &&
+      ShapedType::isDynamic(expectedOffset))
+    return emitError(
+        "expected offset is dynamic, but result type offset is static");
   if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
     return emitError("expected result type with offset = ")
            << (ShapedType::isDynamic(expectedOffset)
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 18b151c469da6..fbc1b8ca42377 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -346,26 +346,26 @@ func.func @reinterpret_cast(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgr
 
 // CHECK-LABEL: func.func @reinterpret_cast_0
 //  CHECK-SAME:  (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>)
-func.func @reinterpret_cast_0(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>> {
+func.func @reinterpret_cast_0(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?xf32, strided<[1]>, #spirv.storage_class<CrossWorkgroup>> {
 //   CHECK-DAG:  %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
-//   CHECK-DAG:  %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+//   CHECK-DAG:  %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1]>, #spirv.storage_class<CrossWorkgroup>>
 //       CHECK:  return %[[RET]]
   %c10 = arith.constant 10 : index
-  %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
-  return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+  %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: 0>, #spirv.storage_class<CrossWorkgroup>>
+  return %ret : memref<?xf32, strided<[1], offset: 0>, #spirv.storage_class<CrossWorkgroup>>
 }
 
 // CHECK-LABEL: func.func @reinterpret_cast_5
 //  CHECK-SAME:  (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>)
-func.func @reinterpret_cast_5(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>> {
+func.func @reinterpret_cast_5(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?xf32, strided<[1], offset: 5>, #spirv.storage_class<CrossWorkgroup>> {
 //       CHECK:  %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
 //       CHECK:  %[[OFF:.*]] = spirv.Constant 5 : i32
 //       CHECK:  %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
-//       CHECK:  %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+//       CHECK:  %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: 5>, #spirv.storage_class<CrossWorkgroup>>
 //       CHECK:  return %[[RET1]]
   %c10 = arith.constant 10 : index
-  %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
-  return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+  %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [%c10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: 5>, #spirv.storage_class<CrossWorkgroup>>
+  return %ret : memref<?xf32, strided<[1], offset: 5>, #spirv.storage_class<CrossWorkgroup>>
 }
 
 } // end module
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a53a5d10eceb5..7d293fcec0083 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -925,11 +925,11 @@ func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
 //       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x2xf32,
 //       CHECK: return %[[CAST]]
-func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x2xf32, strided<[?, ?], offset: ?>> {
+func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x2xf32, strided<[?, ?], offset: 0>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
   %c8 = arith.constant 8: index
-  %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x2xf32, strided<[?, ?], offset: ?>>
-  return %m2 : memref<?x2xf32, strided<[?, ?], offset: ?>>
+  %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x2xf32, strided<[?, ?], offset: 0>>
+  return %m2 : memref<?x2xf32, strided<[?, ?], offset: 0>>
 }
 // -----
 
@@ -970,10 +970,10 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
 //       CHECK: return %[[RES]]
-func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: 1>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
-  %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
-  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: 1>>
+  return %m2 : memref<?x?xf32, strided<[?, ?], offset: 1>>
 }
 
 // -----
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index c98d4913dc5d2..68d88e9214705 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -214,16 +214,6 @@ func.func @memref_reinterpret_cast_no_map_but_offset(%in: memref<?xf32>) {
          : memref<?xf32> to memref<10xf32>
   return
 }
-
-// -----
-
-func.func @memref_reinterpret_cast_offset_mismatch_dynamic(%in: memref<?xf32>, %offset : index) {
-  // expected-error @+1 {{expected result type with offset = dynamic instead of 0}}
-  %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
-         : memref<?xf32> to memref<10xf32>
-  return
-}
-
 // -----
 
 func.func @memref_reinterpret_cast_no_map_but_stride(%in: memref<?xf32>) {
@@ -246,7 +236,7 @@ func.func @memref_reinterpret_cast_no_map_but_strides(%in: memref<?x?xf32>) {
 // -----
 
 func.func @memref_reinterpret_cast_static_dynamic_size_mismatch(%in: memref<1x?x2x1xf32>) {
-  // expected-error at +1 {{expectedSize is static but received a dynamic resultSize}}
+  // expected-error at +1 {{expected size is static, but result type dimension is dynamic }}
   %out = memref.reinterpret_cast %in to
          offset: [0], sizes: [1, 4672, 1, 1], strides: [4672, 8, 8, 1]
        : memref<1x?x2x1xf32> to memref<1x4672x?x1xf32> 
@@ -255,13 +245,32 @@ func.func @memref_reinterpret_cast_static_dynamic_size_mismatch(%in: memref<1x?x
 // -----
 
 func.func @memref_reinterpret_cast_dynamic_static_size_mismatch(%in: memref<1x?x2x1xf32>, %size: index) {
-  // expected-error at +1 {{expectedSize is dynamic but received a static resultSize}}
+  // expected-error at +1 {{expected size is dynamic, but result type dimension is static }}
   %out = memref.reinterpret_cast %in to
          offset: [0], sizes: [1, %size, 1, 1], strides: [4672, 8, 8, 1]
        : memref<1x?x2x1xf32> to memref<1x4672x2x1xf32>
   return
 }
 
+// -----
+
+func.func @memref_reinterpret_cast_static_dynamic_offset_mismatch(%in: memref<?xf32>) {
+  // expected-error at +1 {{expected offset is static, but result type offset is dynamic}}
+  %out = memref.reinterpret_cast %in to
+         offset: [0], sizes: [10], strides: [1]
+       : memref<?xf32> to memref<10xf32, strided<[1], offset: ?>> 
+}
+
+// -----
+
+func.func @memref_reinterpret_cast_dynamic_static_offset_mismatch(%in: memref<?xf32>, %offset: index) {
+  // expected-error at +1 {{expected offset is dynamic, but result type offset is static}}
+  %out = memref.reinterpret_cast %in to
+         offset: [%offset], sizes: [10], strides: [1]
+       : memref<?xf32> to memref<10xf32, strided<[1], offset: 0>> 
+  return
+}
+
 // -----
 func.func @memref_reshape_element_type_mismatch(
        %buf: memref<*xf32>, %shape: memref<1xi32>) {
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 03e344e0e9cf2..0685334cd20ea 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -131,11 +131,11 @@ func.func @memref_reinterpret_cast(%in: memref<?xf32>)
 
 // CHECK-LABEL: func @memref_reinterpret_cast_static_to_dynamic_sizes
 func.func @memref_reinterpret_cast_static_to_dynamic_sizes(%in: memref<?xf32>)
-    -> memref<10x10xf32, strided<[?, 1], offset: ?>> {
+    -> memref<10x10xf32, strided<[?, 1], offset: 1>> {
   %out = memref.reinterpret_cast %in to
            offset: [1], sizes: [10, 10], strides: [1, 1]
-           : memref<?xf32> to memref<10x10xf32, strided<[?, 1], offset: ?>>
-  return %out : memref<10x10xf32, strided<[?, 1], offset: ?>>
+           : memref<?xf32> to memref<10x10xf32, strided<[?, 1], offset: 1>>
+  return %out : memref<10x10xf32, strided<[?, 1], offset: 1>>
 }
 
 // CHECK-LABEL: func @memref_reinterpret_cast_dynamic_offset



More information about the Mlir-commits mailing list