[Mlir-commits] [mlir] d2ed2f1 - [mlir] Add MemRefReshapeOp definition to Standard.

Alexander Belyaev llvmlistbot at llvm.org
Thu Oct 22 04:30:06 PDT 2020


Author: Alexander Belyaev
Date: 2020-10-22T13:29:13+02:00
New Revision: d2ed2f16b853a936c8d0c1c1fc406e7b8e54526c

URL: https://github.com/llvm/llvm-project/commit/d2ed2f16b853a936c8d0c1c1fc406e7b8e54526c
DIFF: https://github.com/llvm/llvm-project/commit/d2ed2f16b853a936c8d0c1c1fc406e7b8e54526c.diff

LOG: [mlir] Add MemRefReshapeOp definition to Standard.

https://llvm.discourse.group/t/rfc-standard-memref-cast-ops/1454/15

Differential Revision: https://reviews.llvm.org/D89784

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/invalid.mlir
    mlir/test/Dialect/Standard/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 31fb780cf145..6de8ace044cc 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2217,6 +2217,73 @@ def MemRefCastOp : CastOp<"memref_cast", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// MemRefReshapeOp
+//===----------------------------------------------------------------------===//
+
+def MemRefReshapeOp: Std_Op<"memref_reshape", [
+    ViewLikeOpInterface, NoSideEffect]>  {
+  let summary = "memref reshape operation";
+  let description = [{
+    The `memref_reshape` operation converts a memref from one type to an
+    equivalent type with a provided shape. The data is never copied or
+    modified. The source and destination types are compatible if both have the
+    same element type, the same number of elements, the same address space and
+    an identity layout map. The following combinations are possible:
+
+    a. Source type is ranked or unranked. Shape argument has static size.
+    Result type is ranked.
+
+    ```mlir
+    // Reshape statically-shaped memref.
+    %dst0 = memref_reshape %src(%shape)
+             : (memref<4x1xf32>, memref<1xi32>) -> memref<4xf32>
+    %dst1 = memref_reshape %src(%shape0)
+             : (memref<4x1xf32>, memref<2xi32>) -> memref<2x2xf32>
+    // Flatten unranked memref.
+    %dst2 = memref_reshape %src(%shape)
+             : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
+    ```
+
+    b. Source type is ranked or unranked. Shape argument has dynamic size.
+    Result type is unranked.
+
+    ```mlir
+    // Reshape dynamically-shaped 1D memref.
+    %dst3 = memref_reshape %src(%shape)
+             : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
+    // Reshape unranked memref.
+    %dst4 = memref_reshape %src(%shape)
+             : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    AnyRankedOrUnrankedMemRef:$source,
+    MemRefRankOf<[AnySignlessInteger], [1]>:$shape
+  );
+  let results = (outs AnyRankedOrUnrankedMemRef:$result);
+
+  // Builds the op from an explicitly provided ranked result type.
+  let builders = [OpBuilder<
+     "MemRefType resultType, Value operand, Value shape", [{
+       $_state.addOperands(operand);
+       $_state.addOperands(shape);
+       $_state.addTypes(resultType);
+     }]>];
+
+  let extraClassDeclaration = [{
+    // NOTE(review): `result` may also be an unranked memref; this cast asserts
+    // in that case, so only call getType() when the result is ranked.
+    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+    Value getViewSource() { return source(); }
+  }];
+
+  let assemblyFormat = [{
+    $source `(` $shape `)` attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 9fe94fe75327..7010ad5f34c9 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2145,6 +2145,50 @@ OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
   return impl::foldCastOp(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// MemRefReshapeOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies a MemRefReshapeOp: the source and result element types must match,
+/// a ranked source or result must use the identity layout (no affine maps),
+/// and a ranked result needs a statically-sized shape operand whose length
+/// equals the result rank. Unranked results accept any shape operand length.
+/// NOTE(review): the op description also requires matching element counts and
+/// address spaces; neither is verified here.
+static LogicalResult verify(MemRefReshapeOp op) {
+  Type operandType = op.source().getType();
+  Type resultType = op.result().getType();
+
+  Type operandElementType = operandType.cast<ShapedType>().getElementType();
+  Type resultElementType = resultType.cast<ShapedType>().getElementType();
+  if (operandElementType != resultElementType)
+    return op.emitOpError("element types of source and destination memref "
+                          "types should be the same");
+
+  // Unranked memrefs carry no layout map, so only a ranked source is checked.
+  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
+    if (!operandMemRefType.getAffineMaps().empty())
+      return op.emitOpError(
+          "source memref type should have identity affine map");
+
+  // The shape operand is constrained to a 1-D memref, so dim 0 is its length
+  // (which may be dynamic).
+  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
+  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
+  if (resultMemRefType) {
+    if (!resultMemRefType.getAffineMaps().empty())
+      return op.emitOpError(
+          "result memref type should have identity affine map");
+    if (shapeSize == ShapedType::kDynamicSize)
+      return op.emitOpError("cannot use shape operand with dynamic length to "
+                            "reshape to statically-ranked memref type");
+    if (shapeSize != resultMemRefType.getRank())
+      return op.emitOpError(
+          "length of shape operand differs from the result's memref rank");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
index 72fe5c227578..8047ad94f588 100644
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -102,3 +102,53 @@ func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off
   // expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
   transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
+
+// -----
+
+// CHECK-LABEL: memref_reshape_element_type_mismatch
+func @memref_reshape_element_type_mismatch(
+       %buf: memref<*xf32>, %shape: memref<1xi32>) {
+  // expected-error @+1 {{element types of source and destination memref types should be the same}}
+  memref_reshape %buf(%shape) : (memref<*xf32>, memref<1xi32>) -> memref<?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: memref_reshape_dst_ranked_shape_unranked
+func @memref_reshape_dst_ranked_shape_unranked(
+       %buf: memref<*xf32>, %shape: memref<?xi32>) {
+  // expected-error @+1 {{cannot use shape operand with dynamic length to reshape to statically-ranked memref type}}
+  memref_reshape %buf(%shape) : (memref<*xf32>, memref<?xi32>) -> memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: memref_reshape_dst_shape_rank_mismatch
+func @memref_reshape_dst_shape_rank_mismatch(
+       %buf: memref<*xf32>, %shape: memref<1xi32>) {
+  // expected-error @+1 {{length of shape operand differs from the result's memref rank}}
+  memref_reshape %buf(%shape)
+    : (memref<*xf32>, memref<1xi32>) -> memref<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: memref_reshape_src_affine_map_is_not_identity
+func @memref_reshape_src_affine_map_is_not_identity(
+        %buf: memref<4x4xf32, offset: 0, strides: [3, 2]>,
+        %shape: memref<1xi32>) {
+  // expected-error @+1 {{source memref type should have identity affine map}}
+  memref_reshape %buf(%shape)
+    : (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>)
+    -> memref<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: memref_reshape_result_affine_map_is_not_identity
+func @memref_reshape_result_affine_map_is_not_identity(
+        %buf: memref<4x4xf32>, %shape: memref<1xi32>) {
+  // expected-error @+1 {{result memref type should have identity affine map}}
+  memref_reshape %buf(%shape)
+    : (memref<4x4xf32>, memref<1xi32>) -> memref<16xf32, offset: 0, strides: [2]>
+}

diff  --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
index b11c9534cc2d..501ff07fd5a7 100644
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -54,3 +54,18 @@ func @atan2(%arg0 : f32, %arg1 : f32) -> f32 {
   %result = atan2 %arg0, %arg1 : f32
   return %result : f32
 }
+
+// CHECK-LABEL: func @memref_reshape(
+func @memref_reshape(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
+         %shape2: memref<2xi32>, %shape3: memref<?xi32>) -> memref<*xf32> {
+  // CHECK: memref_reshape {{.*}} : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
+  %dyn_vec = memref_reshape %unranked(%shape1)
+               : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
+  // CHECK: memref_reshape {{.*}} : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
+  %dyn_mat = memref_reshape %dyn_vec(%shape2)
+               : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
+  // CHECK: memref_reshape {{.*}} : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
+  %new_unranked = memref_reshape %dyn_mat(%shape3)
+               : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
+  return %new_unranked : memref<*xf32>
+}


        


More information about the Mlir-commits mailing list