[Mlir-commits] [mlir] 0724911 - [mlir] Add `tensor.reshape`.

Alexander Belyaev llvmlistbot at llvm.org
Thu Apr 22 05:53:39 PDT 2021

Author: Alexander Belyaev
Date: 2021-04-22T14:53:23+02:00
New Revision: 0724911d2a7b10ca4b8f8bbecd754143a2bed3db

URL: https://github.com/llvm/llvm-project/commit/0724911d2a7b10ca4b8f8bbecd754143a2bed3db
DIFF: https://github.com/llvm/llvm-project/commit/0724911d2a7b10ca4b8f8bbecd754143a2bed3db.diff

LOG: [mlir] Add `tensor.reshape`.

This operation is a counterpart of `memref.reshape`.

RFC [Reshape Ops Restructuring](https://llvm.discourse.group/t/rfc-reshape-ops-restructuring/3310)

Differential Revision: https://reviews.llvm.org/D100971




diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 58c9a5c30a735..a0e473873d27a 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -182,6 +182,67 @@ def Tensor_GenerateOp : Tensor_Op<"generate",
   let hasCanonicalizer = 1;
+// ReshapeOp
+// Tensor counterpart of `memref.reshape`; the op's verifier lives in
+// TensorOps.cpp (`verify(ReshapeOp)`).
+def Tensor_ReshapeOp: Tensor_Op<"reshape", [NoSideEffect]>  {
+  let summary = "tensor reshape operation";
+  let description = [{
+    The `reshape` operation converts a tensor from one type to an equivalent
+    type with a provided shape. The source and destination types are compatible
+    if both have the same element type and the same number of elements. The
+    following combinations are possible:
+    a. Source type is ranked or unranked. Shape argument has static size.
+    Result type is ranked.
+    ```mlir
+    // Reshape statically-shaped tensor.
+    %dst = tensor.reshape %src(%shape)
+             : (tensor<4x1xf32>, tensor<1xi32>) -> tensor<4xf32>
+    %dst0 = tensor.reshape %src(%shape0)
+             : (tensor<4x1xf32>, tensor<2xi32>) -> tensor<2x2xf32>
+    // Flatten unranked tensor.
+    %dst = tensor.reshape %src(%shape)
+             : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
+    ```
+    b. Source type is ranked or unranked. Shape argument has dynamic size.
+    Result type is unranked.
+    ```mlir
+    // Reshape dynamically-shaped 1D tensor.
+    %dst = tensor.reshape %src(%shape)
+             : (tensor<?xf32>, tensor<?xi32>) -> tensor<*xf32>
+    // Reshape unranked tensor.
+    %dst = tensor.reshape %src(%shape)
+             : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+    ```
+  }];
+  let arguments = (ins
+    AnyTensor:$source,
+    TensorRankOf<[AnySignlessInteger, Index], [1]>:$shape
+  );
+  let results = (outs AnyTensor:$result);
+  let builders = [OpBuilder<
+     (ins "TensorType":$resultType, "Value":$operand, "Value":$shape), [{
+       $_state.addOperands(operand);
+       $_state.addOperands(shape);
+       $_state.addTypes(resultType);
+     }]>];
+  let extraClassDeclaration = [{
+    TensorType getResultType() { return getResult().getType().cast<TensorType>(); }
+  }];
+  let assemblyFormat = [{
+    $source `(` $shape `)` attr-dict `:` functional-type(operands, results)
+  }];
+}
 // YieldOp

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 9dc9240cc4623..1beb458df1f53 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -442,6 +442,47 @@ void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
+// ReshapeOp
+// Returns the product of all dimension sizes of `type`. Only meaningful for
+// statically shaped types: a dynamic dimension would fold the kDynamicSize
+// sentinel into the product, so callers check hasStaticShape() first.
+static int64_t GetNumElements(ShapedType type) {
+  int64_t numElements = 1;
+  for (auto dim : type.getShape())
+    numElements *= dim;
+  return numElements;
+}
+
+// Verifies a tensor.reshape op:
+//  - source and result must have the same element type;
+//  - when both source and result are ranked with static shapes, their element
+//    counts must match;
+//  - when the result is ranked, the (rank-1) shape operand must have a static
+//    length equal to the result rank.
+// An unranked result places no constraint on the shape operand's length.
+static LogicalResult verify(ReshapeOp op) {
+  TensorType operandType = op.source().getType().cast<TensorType>();
+  TensorType resultType = op.result().getType().cast<TensorType>();
+  if (operandType.getElementType() != resultType.getElementType())
+    return op.emitOpError("element types of source and destination tensor "
+                          "types should be the same");
+  // The ODS constraint on `shape` (TensorRankOf<..., [1]>) guarantees a ranked
+  // 1-D tensor, so this cast and getDimSize(0) are safe.
+  int64_t shapeSize =
+      op.shape().getType().cast<RankedTensorType>().getDimSize(0);
+  auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
+  auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
+  if (resultRankedType) {
+    // Element counts can only be compared when both sides are fully static.
+    if (operandRankedType && resultRankedType.hasStaticShape() &&
+        operandRankedType.hasStaticShape()) {
+      if (GetNumElements(operandRankedType) != GetNumElements(resultRankedType))
+        return op.emitOpError("source and destination tensor should have the "
+                              "same number of elements");
+    }
+    if (shapeSize == TensorType::kDynamicSize)
+      return op.emitOpError("cannot use shape operand with dynamic length to "
+                            "reshape to statically-ranked tensor type");
+    if (shapeSize != resultRankedType.getRank())
+      return op.emitOpError(
+          "length of shape operand differs from the result's tensor rank");
+  }
+  return success();
+}
 // TableGen'd op method definitions

diff  --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 11866990c885a..79fef8c0f8e47 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -97,3 +97,36 @@ func @tensor.generate(%m : index, %n : index)
   } : tensor<?x3x?xf32>
   return %tnsr : tensor<?x3x?xf32>
+// -----
+
+func @tensor.reshape_element_type_mismatch(
+       %buf: tensor<*xf32>, %shape: tensor<1xi32>) {
+  // expected-error @+1 {{element types of source and destination tensor types should be the same}}
+  tensor.reshape %buf(%shape) : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xi32>
+}
+
+// -----
+
+func @tensor.reshape_dst_ranked_shape_unranked(
+       %buf: tensor<*xf32>, %shape: tensor<?xi32>) {
+  // expected-error @+1 {{cannot use shape operand with dynamic length to reshape to statically-ranked tensor type}}
+  tensor.reshape %buf(%shape) : (tensor<*xf32>, tensor<?xi32>) -> tensor<?xf32>
+}
+
+// -----
+
+func @tensor.reshape_dst_shape_rank_mismatch(
+       %buf: tensor<*xf32>, %shape: tensor<1xi32>) {
+  // expected-error @+1 {{length of shape operand differs from the result's tensor rank}}
+  tensor.reshape %buf(%shape)
+    : (tensor<*xf32>, tensor<1xi32>) -> tensor<?x?xf32>
+}
+
+// -----
+
+func @tensor.reshape_num_elements_mismatch(
+       %buf: tensor<1xf32>, %shape: tensor<1xi32>) {
+  // expected-error @+1 {{source and destination tensor should have the same number of elements}}
+  tensor.reshape %buf(%shape)
+    : (tensor<1xf32>, tensor<1xi32>) -> tensor<10xf32>
+}

diff  --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 9b15712058a25..450da06a25938 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -53,3 +53,15 @@ func @tensor.generate(%m : index, %n : index)
   } : tensor<?x3x?xf32>
   return %tnsr : tensor<?x3x?xf32>
+
+// CHECK-LABEL: func @tensor_reshape
+func @tensor_reshape(%unranked: tensor<*xf32>, %shape1: tensor<1xi32>,
+         %shape2: tensor<2xi32>, %shape3: tensor<?xi32>) -> tensor<*xf32> {
+  %dyn_vec = tensor.reshape %unranked(%shape1)
+               : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
+  %dyn_mat = tensor.reshape %dyn_vec(%shape2)
+               : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  %new_unranked = tensor.reshape %dyn_mat(%shape3)
+               : (tensor<?x?xf32>, tensor<?xi32>) -> tensor<*xf32>
+  return %new_unranked : tensor<*xf32>
+}


More information about the Mlir-commits mailing list