[Mlir-commits] [mlir] d2ed2f1 - [mlir] Add MemRefReshapeOp definition to Standard.
Alexander Belyaev
llvmlistbot at llvm.org
Thu Oct 22 04:30:06 PDT 2020
Author: Alexander Belyaev
Date: 2020-10-22T13:29:13+02:00
New Revision: d2ed2f16b853a936c8d0c1c1fc406e7b8e54526c
URL: https://github.com/llvm/llvm-project/commit/d2ed2f16b853a936c8d0c1c1fc406e7b8e54526c
DIFF: https://github.com/llvm/llvm-project/commit/d2ed2f16b853a936c8d0c1c1fc406e7b8e54526c.diff
LOG: [mlir] Add MemRefReshapeOp definition to Standard.
https://llvm.discourse.group/t/rfc-standard-memref-cast-ops/1454/15
Differential Revision: https://reviews.llvm.org/D89784
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/invalid.mlir
mlir/test/Dialect/Standard/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 31fb780cf145..6de8ace044cc 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2217,6 +2217,73 @@ def MemRefCastOp : CastOp<"memref_cast", [
}];
}
+//===----------------------------------------------------------------------===//
+// MemRefReshapeOp
+//===----------------------------------------------------------------------===//
+
+def MemRefReshapeOp: Std_Op<"memref_reshape", [
+    ViewLikeOpInterface, NoSideEffect]> {
+  let summary = "memref reshape operation";
+  let description = [{
+    The `memref_reshape` operation converts a memref from one type to an
+    equivalent type with a provided shape. The data is never copied or
+    modified. The source and destination types are compatible if both have the
+    same element type, same number of elements, address space and identity
+    layout map. The following combinations are possible:
+
+    a. Source type is ranked or unranked. Shape argument has static size.
+    Result type is ranked.
+
+    ```mlir
+    // Reshape statically-shaped memref.
+    %dst = memref_reshape %src(%shape)
+             : (memref<4x1xf32>, memref<1xi32>) -> memref<4xf32>
+    %dst0 = memref_reshape %src(%shape0)
+             : (memref<4x1xf32>, memref<2xi32>) -> memref<2x2xf32>
+    // Flatten unranked memref.
+    %dst1 = memref_reshape %unranked_src(%shape)
+             : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
+    ```
+
+    b. Source type is ranked or unranked. Shape argument has dynamic size.
+    Result type is unranked.
+
+    ```mlir
+    // Reshape dynamically-shaped 1D memref.
+    %dst = memref_reshape %src(%shape)
+             : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
+    // Reshape unranked memref.
+    %dst0 = memref_reshape %unranked_src(%shape)
+             : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    AnyRankedOrUnrankedMemRef:$source,
+    MemRefRankOf<[AnySignlessInteger], [1]>:$shape
+  );
+  let results = (outs AnyRankedOrUnrankedMemRef:$result);
+
+  // Custom builder for the case where the (ranked) result type is known.
+  let builders = [OpBuilder<
+    "MemRefType resultType, Value operand, Value shape", [{
+      $_state.addOperands(operand);
+      $_state.addOperands(shape);
+      $_state.addTypes(resultType);
+    }]>];
+
+  let extraClassDeclaration = [{
+    // Only valid for a ranked result: the result may also be unranked
+    // (memref<*x...>), in which case the cast<MemRefType> below fails.
+    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+    Value getViewSource() { return source(); }
+  }];
+
+  let assemblyFormat = [{
+    $source `(` $shape `)` attr-dict `:` functional-type(operands, results)
+  }];
+}
+
//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 9fe94fe75327..7010ad5f34c9 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2145,6 +2145,48 @@ OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
return impl::foldCastOp(*this);
}
+//===----------------------------------------------------------------------===//
+// MemRefReshapeOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies a `memref_reshape` op:
+///  - source and result must have the same element type;
+///  - a ranked source or ranked result must have no affine maps (identity
+///    layout);
+///  - a ranked result needs a statically sized `shape` operand whose length
+///    equals the result rank. An unranked result accepts any shape length.
+static LogicalResult verify(MemRefReshapeOp op) {
+  Type operandType = op.source().getType();
+  Type resultType = op.result().getType();
+
+  Type operandElementType = operandType.cast<ShapedType>().getElementType();
+  Type resultElementType = resultType.cast<ShapedType>().getElementType();
+  if (operandElementType != resultElementType)
+    return op.emitOpError("element types of source and destination memref "
+                          "types should be the same");
+
+  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
+    if (!operandMemRefType.getAffineMaps().empty())
+      return op.emitOpError(
+          "source memref type should have identity affine map");
+
+  // `shape` is constrained by ODS to a rank-1 memref, so dim 0 is its length.
+  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
+  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
+  if (resultMemRefType) {
+    if (!resultMemRefType.getAffineMaps().empty())
+      return op.emitOpError(
+          "result memref type should have identity affine map");
+    if (shapeSize == ShapedType::kDynamicSize)
+      return op.emitOpError("cannot use shape operand with dynamic length to "
+                            "reshape to statically-ranked memref type");
+    if (shapeSize != resultMemRefType.getRank())
+      return op.emitOpError(
+          "length of shape operand differs from the result's memref rank");
+  }
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
index 72fe5c227578..8047ad94f588 100644
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -102,3 +102,53 @@ func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off
// expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}
+
+// -----
+
+// CHECK-LABEL: memref_reshape_element_type_mismatch
+func @memref_reshape_element_type_mismatch(
+       %buf: memref<*xf32>, %shape: memref<1xi32>) {
+  // expected-error @+1 {{element types of source and destination memref types should be the same}}
+  memref_reshape %buf(%shape) : (memref<*xf32>, memref<1xi32>) -> memref<?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: memref_reshape_dst_ranked_shape_unranked
+func @memref_reshape_dst_ranked_shape_unranked(
+       %buf: memref<*xf32>, %shape: memref<?xi32>) {
+  // expected-error @+1 {{cannot use shape operand with dynamic length to reshape to statically-ranked memref type}}
+  memref_reshape %buf(%shape) : (memref<*xf32>, memref<?xi32>) -> memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: memref_reshape_dst_shape_rank_mismatch
+func @memref_reshape_dst_shape_rank_mismatch(
+       %buf: memref<*xf32>, %shape: memref<1xi32>) {
+  // expected-error @+1 {{length of shape operand differs from the result's memref rank}}
+  memref_reshape %buf(%shape)
+    : (memref<*xf32>, memref<1xi32>) -> memref<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: memref_reshape_src_affine_map_is_not_identity
+func @memref_reshape_src_affine_map_is_not_identity(
+        %buf: memref<4x4xf32, offset: 0, strides: [3, 2]>,
+        %shape: memref<1xi32>) {
+  // expected-error @+1 {{source memref type should have identity affine map}}
+  memref_reshape %buf(%shape)
+    : (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>)
+    -> memref<8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: memref_reshape_result_affine_map_is_not_identity
+func @memref_reshape_result_affine_map_is_not_identity(
+       %buf: memref<4x4xf32>, %shape: memref<1xi32>) {
+  // expected-error @+1 {{result memref type should have identity affine map}}
+  memref_reshape %buf(%shape)
+    : (memref<4x4xf32>, memref<1xi32>) -> memref<8xf32, offset: 0, strides: [2]>
+}
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
index b11c9534cc2d..501ff07fd5a7 100644
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -54,3 +54,15 @@ func @atan2(%arg0 : f32, %arg1 : f32) -> f32 {
%result = atan2 %arg0, %arg1 : f32
return %result : f32
}
+
+// CHECK-LABEL: func @memref_reshape(
+func @memref_reshape(%src: memref<*xf32>, %len1: memref<1xi32>,
+                     %len2: memref<2xi32>, %len_dyn: memref<?xi32>) -> memref<*xf32> {
+  %vec = memref_reshape %src(%len1)      // unranked -> 1-D (static shape length)
+           : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
+  %mat = memref_reshape %vec(%len2)      // 1-D -> 2-D (static shape length)
+           : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
+  %out = memref_reshape %mat(%len_dyn)   // 2-D -> unranked (dynamic shape length)
+           : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
+  return %out : memref<*xf32>
+}
More information about the Mlir-commits
mailing list