[Mlir-commits] [mlir] 33cec20 - [mlir][memref] Tighten verification of memref.reinterpret_cast

Stephan Herhut llvmlistbot at llvm.org
Mon Jan 10 02:56:12 PST 2022


Author: Stephan Herhut
Date: 2022-01-10T11:55:47+01:00
New Revision: 33cec20dbd3bb8a896084da4a87f00d1cf13f77a

URL: https://github.com/llvm/llvm-project/commit/33cec20dbd3bb8a896084da4a87f00d1cf13f77a
DIFF: https://github.com/llvm/llvm-project/commit/33cec20dbd3bb8a896084da4a87f00d1cf13f77a.diff

LOG: [mlir][memref] Tighten verification of memref.reinterpret_cast

We allow the layout map to be omitted from the result type of
memref.reinterpret_cast under the assumption that the cast produces an identity
layout. This change adds verification that the static offset and strides given
in the reinterpret_cast are consistent with this assumption.

Differential Revision: https://reviews.llvm.org/D116601

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/Dialect/MemRef/invalid.mlir
    mlir/test/Dialect/MemRef/ops.mlir
    mlir/test/mlir-cpu-runner/copy.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 5dec3b717bcf0..2e7dc4112592b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1155,40 +1155,44 @@ static LogicalResult verify(ReinterpretCastOp op) {
                                  extractFromI64ArrayAttr(op.static_sizes())))) {
     int64_t resultSize = std::get<0>(en.value());
     int64_t expectedSize = std::get<1>(en.value());
-    if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
+    if (!ShapedType::isDynamic(resultSize) &&
+        !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
       return op.emitError("expected result type with size = ")
              << expectedSize << " instead of " << resultSize
              << " in dim = " << en.index();
   }
 
-  // Match offset and strides in static_offset and static_strides attributes if
-  // result memref type has an affine map specified.
-  if (!resultType.getLayout().isIdentity()) {
-    int64_t resultOffset;
-    SmallVector<int64_t, 4> resultStrides;
-    if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
-      return failure();
-
-    // Match offset in result memref type and in static_offsets attribute.
-    int64_t expectedOffset =
-        extractFromI64ArrayAttr(op.static_offsets()).front();
-    if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
-        resultOffset != expectedOffset)
-      return op.emitError("expected result type with offset = ")
-             << resultOffset << " instead of " << expectedOffset;
-
-    // Match strides in result memref type and in static_strides attribute.
-    for (auto &en : llvm::enumerate(llvm::zip(
-             resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
-      int64_t resultStride = std::get<0>(en.value());
-      int64_t expectedStride = std::get<1>(en.value());
-      if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
-          resultStride != expectedStride)
-        return op.emitError("expected result type with stride = ")
-               << expectedStride << " instead of " << resultStride
-               << " in dim = " << en.index();
-    }
+  // Match offset and strides in static_offset and static_strides attributes. If
+  // result memref type has no affine map specified, this will assume an
+  // identity layout.
+  int64_t resultOffset;
+  SmallVector<int64_t, 4> resultStrides;
+  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+    return op.emitError(
+               "expected result type to have strided layout but found ")
+           << resultType;
+
+  // Match offset in result memref type and in static_offsets attribute.
+  int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front();
+  if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
+      !ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
+      resultOffset != expectedOffset)
+    return op.emitError("expected result type with offset = ")
+           << resultOffset << " instead of " << expectedOffset;
+
+  // Match strides in result memref type and in static_strides attribute.
+  for (auto &en : llvm::enumerate(llvm::zip(
+           resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
+    int64_t resultStride = std::get<0>(en.value());
+    int64_t expectedStride = std::get<1>(en.value());
+    if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
+        !ShapedType::isDynamicStrideOrOffset(expectedStride) &&
+        resultStride != expectedStride)
+      return op.emitError("expected result type with stride = ")
+             << expectedStride << " instead of " << resultStride
+             << " in dim = " << en.index();
   }
+
   return success();
 }
 

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 2e81705049f54..6ddb49de9932e 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -151,7 +151,7 @@ func @multiple_reducing_dims_all_dynamic(%arg0 : memref<?x?x?xf32, offset: ?, st
 //       CHECK:   return %[[SIZE]] : index
 func @dim_of_sized_view(%arg : memref<?xi8>, %size: index) -> index {
   %c0 = arith.constant 0 : index
-  %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [0] : memref<?xi8> to memref<?xi8>
+  %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [1] : memref<?xi8> to memref<?xi8>
   %1 = memref.dim %0, %c0 : memref<?xi8>
   return %1 : index
 }

diff  --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 90f851959748c..51dce9c8b20bb 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -208,6 +208,44 @@ func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
 
 // -----
 
+func @memref_reinterpret_cast_no_map_but_offset(%in: memref<?xf32>) {
+  // expected-error @+1 {{expected result type with offset = 0 instead of 2}}
+  %out = memref.reinterpret_cast %in to offset: [2], sizes: [10], strides: [1]
+         : memref<?xf32> to memref<10xf32>
+  return
+}
+
+// -----
+
+func @memref_reinterpret_cast_no_map_but_stride(%in: memref<?xf32>) {
+  // expected-error @+1 {{expected result type with stride = 10 instead of 1 in dim = 0}}
+  %out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [10]
+         : memref<?xf32> to memref<10xf32>
+  return
+}
+
+// -----
+
+func @memref_reinterpret_cast_no_map_but_strides(%in: memref<?x?xf32>) {
+  // expected-error @+1 {{expected result type with stride = 42 instead of 10 in dim = 0}}
+  %out = memref.reinterpret_cast %in to
+           offset: [0], sizes: [9, 10], strides: [42, 1]
+         : memref<?x?xf32> to memref<9x10xf32>
+  return
+}
+
+// -----
+
+func @memref_reinterpret_cast_non_strided_layout(%in: memref<?x?xf32>) {
+  // expected-error @+1 {{expected result type to have strided layout but found 'memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>}}
+  %out = memref.reinterpret_cast %in to
+           offset: [0], sizes: [9, 10], strides: [42, 1]
+         : memref<?x?xf32> to memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>
+  return
+}
+
+// -----
+
 func @memref_reshape_element_type_mismatch(
        %buf: memref<*xf32>, %shape: memref<1xi32>) {
   // expected-error @+1 {{element types of source and destination memref types should be the same}}

diff  --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 71b6038a2f9d2..1303a896e7eca 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -27,6 +27,15 @@ func @memref_reinterpret_cast_static_to_dynamic_sizes(%in: memref<?xf32>)
   return %out : memref<10x?xf32, offset: ?, strides: [?, 1]>
 }
 
+// CHECK-LABEL: func @memref_reinterpret_cast_dynamic_offset
+func @memref_reinterpret_cast_dynamic_offset(%in: memref<?xf32>, %offset: index)
+    -> memref<10x?xf32, offset: ?, strides: [?, 1]> {
+  %out = memref.reinterpret_cast %in to
+           offset: [%offset], sizes: [10, 10], strides: [1, 1]
+           : memref<?xf32> to memref<10x?xf32, offset: ?, strides: [?, 1]>
+  return %out : memref<10x?xf32, offset: ?, strides: [?, 1]>
+}
+
 // CHECK-LABEL: func @memref_reshape(
 func @memref_reshape(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
          %shape2: memref<2xi32>, %shape3: memref<?xi32>) -> memref<*xf32> {

diff  --git a/mlir/test/mlir-cpu-runner/copy.mlir b/mlir/test/mlir-cpu-runner/copy.mlir
index ae902f8e12006..e5a471fe204dc 100644
--- a/mlir/test/mlir-cpu-runner/copy.mlir
+++ b/mlir/test/mlir-cpu-runner/copy.mlir
@@ -35,9 +35,9 @@ func @main() -> () {
   // CHECK-NEXT: [3,   4,   5]
 
   %copy_two = memref.alloc() : memref<3x2xf32>
-  %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2,3], strides:[1, 2]
-    : memref<3x2xf32> to memref<2x3xf32>
-  memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32>
+  %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides:[1, 2]
+    : memref<3x2xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
+  memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
   %unranked_copy_two = memref.cast %copy_two : memref<3x2xf32> to memref<*xf32>
   call @print_memref_f32(%unranked_copy_two) : (memref<*xf32>) -> ()
   // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1]


        


More information about the Mlir-commits mailing list