[PATCH] D79764: [mlir][Linalg] Allow reshapes to collapse to a zero-rank tensor.
Mahesh Ravishankar via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Mon May 11 23:11:30 PDT 2020
mravishankar created this revision.
Herald added subscribers: llvm-commits, Kayjukh, frgossen, grosul1, Joonsoo, stephenneuendorffer, liufengdb, aartbik, lucyrfox, mgester, arpith-jacob, nicolasvasilache, antiagainst, shauheen, jpienaar, rriddle, mehdi_amini.
Herald added a reviewer: nicolasvasilache.
Herald added a project: LLVM.
mravishankar added a child revision: D79765: [mlir][Linalg] Add folders and canonicalizers for linalg.reshape/linalg.tensor_reshape operations..
This is only valid if the source tensor (result tensor) is statically
shaped with all unit extents when the reshape is collapsing
(expanding) dimensions.
Depends On D79763 <https://reviews.llvm.org/D79763>
Repository:
rG LLVM Github Monorepo
https://reviews.llvm.org/D79764
Files:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/llvm.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
Index: mlir/test/Dialect/Linalg/roundtrip.mlir
===================================================================
--- mlir/test/Dialect/Linalg/roundtrip.mlir
+++ mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -633,3 +633,26 @@
// CHECK: linalg.batch_matmul
// CHECK: linalg.batch_matmul
+// -----
+
+func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>) -> (tensor<f32>, tensor<1x1xf32>)
+{
+ %0 = linalg.tensor_reshape %arg0 [] : tensor<1x1xf32> into tensor<f32>
+ %1 = linalg.tensor_reshape %0 [] : tensor<f32> into tensor<1x1xf32>
+ return %0, %1 : tensor<f32>, tensor<1x1xf32>
+}
+// CHECK-LABEL: func @tensor_reshape_zero_dim
+// CHECK: linalg.tensor_reshape %{{.*}} [] : tensor<1x1xf32> into tensor<f32>
+// CHECK: linalg.tensor_reshape %{{.*}} [] : tensor<f32> into tensor<1x1xf32>
+
+// -----
+
+func @memref_reshape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>) -> (memref<f32>, memref<1x1xf32>)
+{
+ %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
+ %1 = linalg.reshape %0 [] : memref<f32> into memref<1x1xf32>
+ return %0, %1 : memref<f32>, memref<1x1xf32>
+}
+// CHECK-LABEL: func @memref_reshape_zero_dim
+// CHECK: linalg.reshape %{{.*}} [] : memref<1x1xf32> into memref<f32>
+// CHECK: linalg.reshape %{{.*}} [] : memref<f32> into memref<1x1xf32>
Index: mlir/test/Dialect/Linalg/llvm.mlir
===================================================================
--- mlir/test/Dialect/Linalg/llvm.mlir
+++ mlir/test/Dialect/Linalg/llvm.mlir
@@ -273,3 +273,32 @@
// CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+
+func @reshape_zero_dim(%arg0 : memref<1x1xf32>) {
+ %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
+ %1 = linalg.reshape %0 [] : memref<f32> into memref<1x1xf32>
+ return
+}
+// CHECK-LABEL: func @reshape_zero_dim
+// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
Index: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -553,11 +553,21 @@
std::swap(expandedRank, collapsedRank);
std::swap(expandedType, collapsedType);
}
- if (expandedRank == 0 || collapsedRank == 0)
+ if (expandedRank == 0)
return op.emitOpError("expected non-zero memref ranks");
if (expandedRank == collapsedRank)
return op.emitOpError("expected to collapse or expand dims");
+ if (collapsedRank == 0) {
+ // If collapsed rank is 0, then expanded type must be static shaped and of
+ // sizes 1.
+ if (llvm::any_of(expandedType.getShape(),
+ [](int64_t dim) -> bool { return dim != 1; }))
+ return op.emitOpError(
+ "invalid to reshape tensor/memref with non-unit extent dimensions to "
+ "zero-rank tensor/memref");
+ return success();
+ }
if (collapsedRank != op.reassociation().size())
return op.emitOpError("expected rank of the collapsed type(")
<< collapsedRank << ") to be the number of reassociation maps("
-------------- next part --------------
A non-text attachment was scrubbed...
Name: D79764.263349.patch
Type: text/x-patch
Size: 5304 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20200512/50d126a1/attachment-0001.bin>
More information about the llvm-commits
mailing list