[Mlir-commits] [mlir] 0724911 - [mlir] Add `tensor.reshape`.
Alexander Belyaev
llvmlistbot at llvm.org
Thu Apr 22 05:53:39 PDT 2021
Author: Alexander Belyaev
Date: 2021-04-22T14:53:23+02:00
New Revision: 0724911d2a7b10ca4b8f8bbecd754143a2bed3db
URL: https://github.com/llvm/llvm-project/commit/0724911d2a7b10ca4b8f8bbecd754143a2bed3db
DIFF: https://github.com/llvm/llvm-project/commit/0724911d2a7b10ca4b8f8bbecd754143a2bed3db.diff
LOG: [mlir] Add `tensor.reshape`.
This operation is a counterpart of `memref.reshape`.
RFC [Reshape Ops Restructuring](https://llvm.discourse.group/t/rfc-reshape-ops-restructuring/3310)
Differential Revision: https://reviews.llvm.org/D100971
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/invalid.mlir
mlir/test/Dialect/Tensor/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 58c9a5c30a735..a0e473873d27a 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -182,6 +182,67 @@ def Tensor_GenerateOp : Tensor_Op<"generate",
let hasCanonicalizer = 1;
}
+//===----------------------------------------------------------------------===//
+// ReshapeOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_ReshapeOp: Tensor_Op<"reshape", [NoSideEffect]> {
+  let summary = "tensor reshape operation";
+  let description = [{
+    The `reshape` operation converts a tensor from one type to an equivalent
+    type with a provided shape. The source and destination types are compatible
+    if both have the same element type and the same number of elements. The
+    number of elements is only checked when both the source and the result
+    types are statically shaped. The following combinations are possible:
+
+    a. Source type is ranked or unranked. Shape argument has static size.
+    Result type is ranked.
+
+    ```mlir
+    // Reshape statically-shaped tensor.
+    %dst = tensor.reshape %src(%shape)
+             : (tensor<4x1xf32>, tensor<1xi32>) -> tensor<4xf32>
+    %dst0 = tensor.reshape %src(%shape0)
+             : (tensor<4x1xf32>, tensor<2xi32>) -> tensor<2x2xf32>
+    // Flatten unranked tensor.
+    %dst1 = tensor.reshape %src(%shape)
+             : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
+    ```
+
+    b. Source type is ranked or unranked. Shape argument has dynamic size.
+    Result type is unranked.
+
+    ```mlir
+    // Reshape dynamically-shaped 1D tensor.
+    %dst = tensor.reshape %src(%shape)
+             : (tensor<?xf32>, tensor<?xi32>) -> tensor<*xf32>
+    // Reshape unranked tensor.
+    %dst0 = tensor.reshape %src(%shape)
+             : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    AnyTensor:$source,
+    // 1-D tensor holding the target shape; its length is the result rank.
+    TensorRankOf<[AnySignlessInteger, Index], [1]>:$shape
+  );
+  let results = (outs AnyTensor:$result);
+
+  // Builds a reshape of `operand` to `resultType` driven by the `shape` tensor.
+  let builders = [OpBuilder<
+     (ins "TensorType":$resultType, "Value":$operand, "Value":$shape), [{
+       $_state.addOperands(operand);
+       $_state.addOperands(shape);
+       $_state.addTypes(resultType);
+     }]>];
+
+  let extraClassDeclaration = [{
+    TensorType getResultType() {
+      return getResult().getType().cast<TensorType>();
+    }
+  }];
+
+  let assemblyFormat = [{
+    $source `(` $shape `)` attr-dict `:` functional-type(operands, results)
+  }];
+}
+
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 9dc9240cc4623..1beb458df1f53 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -442,6 +442,47 @@ void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
StaticTensorGenerate>(context);
}
+//===----------------------------------------------------------------------===//
+// ReshapeOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies a `tensor.reshape` operation:
+///   * the source and result tensors must share the same element type;
+///   * when both are statically shaped, they must hold the same number of
+///     elements;
+///   * when the result is ranked, the shape operand must have a static length
+///     equal to the result rank (so a dynamic-length shape operand is only
+///     accepted together with an unranked result).
+static LogicalResult verify(ReshapeOp op) {
+  TensorType operandType = op.source().getType().cast<TensorType>();
+  TensorType resultType = op.result().getType().cast<TensorType>();
+
+  if (operandType.getElementType() != resultType.getElementType())
+    return op.emitOpError("element types of source and destination tensor "
+                          "types should be the same");
+
+  // The shape operand is constrained to rank 1 by its ODS type, so dim 0 is
+  // its length (which may be dynamic).
+  int64_t shapeSize =
+      op.shape().getType().cast<RankedTensorType>().getDimSize(0);
+  auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
+  auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
+
+  // Nothing further can be checked statically against an unranked result.
+  if (!resultRankedType)
+    return success();
+
+  // Element counts are only comparable when both shapes are fully static;
+  // ShapedType::getNumElements() requires a static shape.
+  if (operandRankedType && operandRankedType.hasStaticShape() &&
+      resultRankedType.hasStaticShape() &&
+      operandRankedType.getNumElements() != resultRankedType.getNumElements())
+    return op.emitOpError("source and destination tensor should have the "
+                          "same number of elements");
+
+  if (shapeSize == TensorType::kDynamicSize)
+    return op.emitOpError("cannot use shape operand with dynamic length to "
+                          "reshape to statically-ranked tensor type");
+  if (shapeSize != resultRankedType.getRank())
+    return op.emitOpError(
+        "length of shape operand differs from the result's tensor rank");
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 11866990c885a..79fef8c0f8e47 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -97,3 +97,36 @@ func @tensor.generate(%m : index, %n : index)
} : tensor<?x3x?xf32>
return %tnsr : tensor<?x3x?xf32>
}
+// -----
+
+func @tensor.reshape_element_type_mismatch(
+    %src: tensor<*xf32>, %shape: tensor<1xi32>) {
+  // expected-error @+1 {{element types of source and destination tensor types should be the same}}
+  tensor.reshape %src(%shape)
+      : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xi32>
+}
+
+// -----
+
+func @tensor.reshape_dst_ranked_shape_unranked(
+    %src: tensor<*xf32>, %shape: tensor<?xi32>) {
+  // expected-error @+1 {{cannot use shape operand with dynamic length to reshape to statically-ranked tensor type}}
+  tensor.reshape %src(%shape)
+      : (tensor<*xf32>, tensor<?xi32>) -> tensor<?xf32>
+}
+
+// -----
+
+func @tensor.reshape_dst_shape_rank_mismatch(
+    %src: tensor<*xf32>, %shape: tensor<1xi32>) {
+  // The shape operand has length 1 but the result type has rank 2.
+  // expected-error @+1 {{length of shape operand differs from the result's tensor rank}}
+  tensor.reshape %src(%shape)
+      : (tensor<*xf32>, tensor<1xi32>) -> tensor<?x?xf32>
+}
+
+// -----
+
+func @tensor.reshape_num_elements_mismatch(
+    %src: tensor<1xf32>, %shape: tensor<1xi32>) {
+  // expected-error @+1 {{source and destination tensor should have the same number of elements}}
+  tensor.reshape %src(%shape)
+      : (tensor<1xf32>, tensor<1xi32>) -> tensor<10xf32>
+}
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 9b15712058a25..450da06a25938 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -53,3 +53,15 @@ func @tensor.generate(%m : index, %n : index)
} : tensor<?x3x?xf32>
return %tnsr : tensor<?x3x?xf32>
}
+
+// CHECK-LABEL: func @tensor_reshape
+func @tensor_reshape(%unranked: tensor<*xf32>, %shape1: tensor<1xi32>,
+         %shape2: tensor<2xi32>, %shape3: tensor<?xi32>) -> tensor<*xf32> {
+  // CHECK: tensor.reshape {{.*}} : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
+  %dyn_vec = tensor.reshape %unranked(%shape1)
+               : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
+  // CHECK: tensor.reshape {{.*}} : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  %dyn_mat = tensor.reshape %dyn_vec(%shape2)
+               : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  // CHECK: tensor.reshape {{.*}} : (tensor<?x?xf32>, tensor<?xi32>) -> tensor<*xf32>
+  %new_unranked = tensor.reshape %dyn_mat(%shape3)
+                    : (tensor<?x?xf32>, tensor<?xi32>) -> tensor<*xf32>
+  return %new_unranked : tensor<*xf32>
+}
More information about the Mlir-commits
mailing list