[Mlir-commits] [mlir] 33cec20 - [mlir][memref] Tighten verification of memref.reinterpret_cast
Stephan Herhut
llvmlistbot at llvm.org
Mon Jan 10 02:56:12 PST 2022
Author: Stephan Herhut
Date: 2022-01-10T11:55:47+01:00
New Revision: 33cec20dbd3bb8a896084da4a87f00d1cf13f77a
URL: https://github.com/llvm/llvm-project/commit/33cec20dbd3bb8a896084da4a87f00d1cf13f77a
DIFF: https://github.com/llvm/llvm-project/commit/33cec20dbd3bb8a896084da4a87f00d1cf13f77a.diff
LOG: [mlir][memref] Tighten verification of memref.reinterpret_cast
We allow the omission of a map in memref.reinterpret_cast under the assumption
that the cast might cast to an identity layout. This change adds verification
that the static knowledge present in the reinterpret_cast supports
this assumption.
Differential Revision: https://reviews.llvm.org/D116601
Added:
Modified:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/MemRef/invalid.mlir
mlir/test/Dialect/MemRef/ops.mlir
mlir/test/mlir-cpu-runner/copy.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 5dec3b717bcf0..2e7dc4112592b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1155,40 +1155,44 @@ static LogicalResult verify(ReinterpretCastOp op) {
extractFromI64ArrayAttr(op.static_sizes())))) {
int64_t resultSize = std::get<0>(en.value());
int64_t expectedSize = std::get<1>(en.value());
- if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
+ if (!ShapedType::isDynamic(resultSize) &&
+ !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
return op.emitError("expected result type with size = ")
<< expectedSize << " instead of " << resultSize
<< " in dim = " << en.index();
}
- // Match offset and strides in static_offset and static_strides attributes if
- // result memref type has an affine map specified.
- if (!resultType.getLayout().isIdentity()) {
- int64_t resultOffset;
- SmallVector<int64_t, 4> resultStrides;
- if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
- return failure();
-
- // Match offset in result memref type and in static_offsets attribute.
- int64_t expectedOffset =
- extractFromI64ArrayAttr(op.static_offsets()).front();
- if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
- resultOffset != expectedOffset)
- return op.emitError("expected result type with offset = ")
- << resultOffset << " instead of " << expectedOffset;
-
- // Match strides in result memref type and in static_strides attribute.
- for (auto &en : llvm::enumerate(llvm::zip(
- resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
- int64_t resultStride = std::get<0>(en.value());
- int64_t expectedStride = std::get<1>(en.value());
- if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
- resultStride != expectedStride)
- return op.emitError("expected result type with stride = ")
- << expectedStride << " instead of " << resultStride
- << " in dim = " << en.index();
- }
+ // Match offset and strides in static_offset and static_strides attributes. If
+ // result memref type has no affine map specified, this will assume an
+ // identity layout.
+ int64_t resultOffset;
+ SmallVector<int64_t, 4> resultStrides;
+ if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+ return op.emitError(
+ "expected result type to have strided layout but found ")
+ << resultType;
+
+ // Match offset in result memref type and in static_offsets attribute.
+ int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front();
+ if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
+ !ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
+ resultOffset != expectedOffset)
+ return op.emitError("expected result type with offset = ")
+ << resultOffset << " instead of " << expectedOffset;
+
+ // Match strides in result memref type and in static_strides attribute.
+ for (auto &en : llvm::enumerate(llvm::zip(
+ resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
+ int64_t resultStride = std::get<0>(en.value());
+ int64_t expectedStride = std::get<1>(en.value());
+ if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
+ !ShapedType::isDynamicStrideOrOffset(expectedStride) &&
+ resultStride != expectedStride)
+ return op.emitError("expected result type with stride = ")
+ << expectedStride << " instead of " << resultStride
+ << " in dim = " << en.index();
}
+
return success();
}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 2e81705049f54..6ddb49de9932e 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -151,7 +151,7 @@ func @multiple_reducing_dims_all_dynamic(%arg0 : memref<?x?x?xf32, offset: ?, st
// CHECK: return %[[SIZE]] : index
func @dim_of_sized_view(%arg : memref<?xi8>, %size: index) -> index {
%c0 = arith.constant 0 : index
- %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [0] : memref<?xi8> to memref<?xi8>
+ %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [1] : memref<?xi8> to memref<?xi8>
%1 = memref.dim %0, %c0 : memref<?xi8>
return %1 : index
}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 90f851959748c..51dce9c8b20bb 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -208,6 +208,44 @@ func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
// -----
+func @memref_reinterpret_cast_no_map_but_offset(%in: memref<?xf32>) {
+ // expected-error @+1 {{expected result type with offset = 0 instead of 2}}
+ %out = memref.reinterpret_cast %in to offset: [2], sizes: [10], strides: [1]
+ : memref<?xf32> to memref<10xf32>
+ return
+}
+
+// -----
+
+func @memref_reinterpret_cast_no_map_but_stride(%in: memref<?xf32>) {
+ // expected-error @+1 {{expected result type with stride = 10 instead of 1 in dim = 0}}
+ %out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [10]
+ : memref<?xf32> to memref<10xf32>
+ return
+}
+
+// -----
+
+func @memref_reinterpret_cast_no_map_but_strides(%in: memref<?x?xf32>) {
+ // expected-error @+1 {{expected result type with stride = 42 instead of 10 in dim = 0}}
+ %out = memref.reinterpret_cast %in to
+ offset: [0], sizes: [9, 10], strides: [42, 1]
+ : memref<?x?xf32> to memref<9x10xf32>
+ return
+}
+
+// -----
+
+func @memref_reinterpret_cast_non_strided_layout(%in: memref<?x?xf32>) {
+  // expected-error @+1 {{expected result type to have strided layout but found 'memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>'}}
+ %out = memref.reinterpret_cast %in to
+ offset: [0], sizes: [9, 10], strides: [42, 1]
+ : memref<?x?xf32> to memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>
+ return
+}
+
+// -----
+
func @memref_reshape_element_type_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) {
// expected-error @+1 {{element types of source and destination memref types should be the same}}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 71b6038a2f9d2..1303a896e7eca 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -27,6 +27,15 @@ func @memref_reinterpret_cast_static_to_dynamic_sizes(%in: memref<?xf32>)
return %out : memref<10x?xf32, offset: ?, strides: [?, 1]>
}
+// CHECK-LABEL: func @memref_reinterpret_cast_dynamic_offset
+func @memref_reinterpret_cast_dynamic_offset(%in: memref<?xf32>, %offset: index)
+ -> memref<10x?xf32, offset: ?, strides: [?, 1]> {
+ %out = memref.reinterpret_cast %in to
+ offset: [%offset], sizes: [10, 10], strides: [1, 1]
+ : memref<?xf32> to memref<10x?xf32, offset: ?, strides: [?, 1]>
+ return %out : memref<10x?xf32, offset: ?, strides: [?, 1]>
+}
+
// CHECK-LABEL: func @memref_reshape(
func @memref_reshape(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
%shape2: memref<2xi32>, %shape3: memref<?xi32>) -> memref<*xf32> {
diff --git a/mlir/test/mlir-cpu-runner/copy.mlir b/mlir/test/mlir-cpu-runner/copy.mlir
index ae902f8e12006..e5a471fe204dc 100644
--- a/mlir/test/mlir-cpu-runner/copy.mlir
+++ b/mlir/test/mlir-cpu-runner/copy.mlir
@@ -35,9 +35,9 @@ func @main() -> () {
// CHECK-NEXT: [3, 4, 5]
%copy_two = memref.alloc() : memref<3x2xf32>
- %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2,3], strides:[1, 2]
- : memref<3x2xf32> to memref<2x3xf32>
- memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32>
+ %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides:[1, 2]
+ : memref<3x2xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
+ memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
%unranked_copy_two = memref.cast %copy_two : memref<3x2xf32> to memref<*xf32>
call @print_memref_f32(%unranked_copy_two) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1]
More information about the Mlir-commits
mailing list