[Mlir-commits] [mlir] d2a9569 - [mlir][Linalg] Allow reshapes to collapse to a zero-rank tensor.
llvmlistbot at llvm.org
Tue May 12 23:03:59 PDT 2020
Author: MaheshRavishankar
Date: 2020-05-12T23:03:25-07:00
New Revision: d2a9569850166083638a8fb88f7dae1c1b62a926
URL: https://github.com/llvm/llvm-project/commit/d2a9569850166083638a8fb88f7dae1c1b62a926
DIFF: https://github.com/llvm/llvm-project/commit/d2a9569850166083638a8fb88f7dae1c1b62a926.diff
LOG: [mlir][Linalg] Allow reshapes to collapse to a zero-rank tensor.
This is only valid if the source tensor (result tensor) is statically
shaped with all unit extents when the reshape is collapsing
(expanding) dimensions.
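For illustration (mirroring the new roundtrip.mlir tests in the diff below), a
statically shaped tensor or memref whose dimensions are all unit extents can
now be collapsed to a zero-rank value and expanded back, using an empty
reassociation list:

    %0 = linalg.tensor_reshape %arg0 [] : tensor<1x1xf32> into tensor<f32>
    %1 = linalg.tensor_reshape %0 [] : tensor<f32> into tensor<1x1xf32>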
Differential Revision: https://reviews.llvm.org/D79764
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/llvm.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index af1bc121bf3b..874bda002a59 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -111,6 +111,13 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape">,
rank to obtain the memref with the smaller rank. In the case of a dimension
expansion, the reassociation maps can be interpreted as inverse maps.
+ The result memref type of a reshape when dimensions are collapsed
+ (operand memref type when dimensions are expanded) can be
+ zero-ranked if the operand memref type (or the result memref type
+ when dimensions are expanded) is statically shaped with all
+ dimensions being unit extent. In such cases the reassociation map
+ is empty.
+
Examples:
```mlir
@@ -152,6 +159,13 @@ def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
rank to obtain the tensor with the smaller rank. In the case of a dimension
expansion, the reassociation maps can be interpreted as inverse maps.
+ The result tensor type of a reshape when dimensions are collapsed
+ (operand tensor type when dimensions are expanded) can be
+ zero-ranked if the operand tensor type (or the result tensor type
+ when dimensions are expanded) is statically shaped with all
+ dimensions being unit extent. In such cases the reassociation map
+ is empty.
+
Examples:
```mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index fc2353e4087e..4b1b8cde639e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -438,11 +438,21 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T &expandedType,
std::swap(expandedRank, collapsedRank);
std::swap(expandedType, collapsedType);
}
- if (expandedRank == 0 || collapsedRank == 0)
+ if (expandedRank == 0)
return op.emitOpError("expected non-zero memref ranks");
if (expandedRank == collapsedRank)
return op.emitOpError("expected to collapse or expand dims");
+ if (collapsedRank == 0) {
+ // If collapsed rank is 0, then expanded type must be static shaped and of
+ // sizes 1.
+ if (llvm::any_of(expandedType.getShape(),
+ [](int64_t dim) -> bool { return dim != 1; }))
+ return op.emitOpError(
+ "invalid to reshape tensor/memref with non-unit extent dimensions to "
+ "zero-rank tensor/memref");
+ return success();
+ }
if (collapsedRank != op.reassociation().size())
return op.emitOpError("expected rank of the collapsed type(")
<< collapsedRank << ") to be the number of reassociation maps("
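For illustration (not part of the commit): with the check added above, a
collapse to rank zero is rejected whenever any static extent is not 1. A
hypothetical input like the following would now fail verification with the
new diagnostic:

    // error: invalid to reshape tensor/memref with non-unit extent dimensions
    // to zero-rank tensor/memref
    %0 = linalg.reshape %arg0 [] : memref<1x2xf32> into memref<f32>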
diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index 290e1a2fe4d7..84b3bcb66940 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -273,3 +273,32 @@ func @reshape_static(%arg0: memref<3x4x5xf32>) {
// CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+
+func @reshape_zero_dim(%arg0 : memref<1x1xf32>) {
+ %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
+ %1 = linalg.reshape %0 [] : memref<f32> into memref<1x1xf32>
+ return
+}
+// CHECK-LABEL: func @reshape_zero_dim
+// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 9b85b2874658..5237db79c42e 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -630,3 +630,26 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x
// CHECK: linalg.batch_matmul
// CHECK: linalg.batch_matmul
+// -----
+
+func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>) -> (tensor<f32>, tensor<1x1xf32>)
+{
+ %0 = linalg.tensor_reshape %arg0 [] : tensor<1x1xf32> into tensor<f32>
+ %1 = linalg.tensor_reshape %0 [] : tensor<f32> into tensor<1x1xf32>
+ return %0, %1 : tensor<f32>, tensor<1x1xf32>
+}
+// CHECK-LABEL: func @tensor_reshape_zero_dim
+// CHECK: linalg.tensor_reshape %{{.*}} [] : tensor<1x1xf32> into tensor<f32>
+// CHECK: linalg.tensor_reshape %{{.*}} [] : tensor<f32> into tensor<1x1xf32>
+
+// -----
+
+func @memref_reshape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>) -> (memref<f32>, memref<1x1xf32>)
+{
+ %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
+ %1 = linalg.reshape %0 [] : memref<f32> into memref<1x1xf32>
+ return %0, %1 : memref<f32>, memref<1x1xf32>
+}
+// CHECK-LABEL: func @memref_reshape_zero_dim
+// CHECK: linalg.reshape %{{.*}} [] : memref<1x1xf32> into memref<f32>
+// CHECK: linalg.reshape %{{.*}} [] : memref<f32> into memref<1x1xf32>