[Mlir-commits] [mlir] 28ab10f - [mlir][memref] ReinterpretCast: allow static sizes/strides/offset where affine map expects dynamic

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 21 05:33:47 PST 2021


Author: Butygin
Date: 2021-12-21T16:20:01+03:00
New Revision: 28ab10f40424955761fe53485ca243c24f6fab2f

URL: https://github.com/llvm/llvm-project/commit/28ab10f40424955761fe53485ca243c24f6fab2f
DIFF: https://github.com/llvm/llvm-project/commit/28ab10f40424955761fe53485ca243c24f6fab2f.diff

LOG: [mlir][memref] ReinterpretCast: allow static sizes/strides/offset where affine map expects dynamic

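* There is no reason to forbid this case.
* Also, the user would otherwise get a very unfriendly error such as `expected result type with offset = -9223372036854775808 instead of 1`, where -9223372036854775808 is the internal sentinel value used for a dynamic (`?`) offset rather than a real offset.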

Differential Revision: https://reviews.llvm.org/D114678

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/invalid.mlir
    mlir/test/Dialect/MemRef/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1916ffe36dd66..aa201370c0cfd 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1158,7 +1158,7 @@ static LogicalResult verify(ReinterpretCastOp op) {
                                  extractFromI64ArrayAttr(op.static_sizes())))) {
     int64_t resultSize = std::get<0>(en.value());
     int64_t expectedSize = std::get<1>(en.value());
-    if (resultSize != expectedSize)
+    if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
       return op.emitError("expected result type with size = ")
              << expectedSize << " instead of " << resultSize
              << " in dim = " << en.index();
@@ -1175,7 +1175,8 @@ static LogicalResult verify(ReinterpretCastOp op) {
     // Match offset in result memref type and in static_offsets attribute.
     int64_t expectedOffset =
         extractFromI64ArrayAttr(op.static_offsets()).front();
-    if (resultOffset != expectedOffset)
+    if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
+        resultOffset != expectedOffset)
       return op.emitError("expected result type with offset = ")
              << resultOffset << " instead of " << expectedOffset;
 
@@ -1184,7 +1185,8 @@ static LogicalResult verify(ReinterpretCastOp op) {
              resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
       int64_t resultStride = std::get<0>(en.value());
       int64_t expectedStride = std::get<1>(en.value());
-      if (resultStride != expectedStride)
+      if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
+          resultStride != expectedStride)
         return op.emitError("expected result type with stride = ")
                << expectedStride << " instead of " << resultStride
                << " in dim = " << en.index();

diff  --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 55c5a821fb3dd..97d9db8cf1cca 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -208,18 +208,6 @@ func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
 
 // -----
 
-func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
-  %c0 = arith.constant 0 : index
-  %c10 = arith.constant 10 : index
-  // expected-error @+1 {{expected result type with size = 10 instead of -1 in dim = 0}}
-  %out = memref.reinterpret_cast %in to
-           offset: [%c0], sizes: [10, %c10], strides: [%c10, 1]
-           : memref<?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
-  return
-}
-
-// -----
-
 func @memref_reshape_element_type_mismatch(
        %buf: memref<*xf32>, %shape: memref<1xi32>) {
   // expected-error @+1 {{element types of source and destination memref types should be the same}}

diff  --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 4ff2f8b5517be..963c817af3981 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -18,6 +18,15 @@ func @memref_reinterpret_cast(%in: memref<?xf32>)
   return %out : memref<10x?xf32, offset: ?, strides: [?, 1]>
 }
 
+// CHECK-LABEL: func @memref_reinterpret_cast_static_to_dynamic_sizes
+func @memref_reinterpret_cast_static_to_dynamic_sizes(%in: memref<?xf32>)
+    -> memref<10x?xf32, offset: ?, strides: [?, 1]> {
+  %out = memref.reinterpret_cast %in to
+           offset: [1], sizes: [10, 10], strides: [1, 1]
+           : memref<?xf32> to memref<10x?xf32, offset: ?, strides: [?, 1]>
+  return %out : memref<10x?xf32, offset: ?, strides: [?, 1]>
+}
+
 // CHECK-LABEL: func @memref_reshape(
 func @memref_reshape(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
          %shape2: memref<2xi32>, %shape3: memref<?xi32>) -> memref<*xf32> {


        


More information about the Mlir-commits mailing list