[PATCH] D79764: [mlir][Linalg] Allow reshapes to collapse to a zero-rank tensor.

Mahesh Ravishankar via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Mon May 11 23:11:30 PDT 2020


mravishankar created this revision.
Herald added subscribers: llvm-commits, Kayjukh, frgossen, grosul1, Joonsoo, stephenneuendorffer, liufengdb, aartbik, lucyrfox, mgester, arpith-jacob, nicolasvasilache, antiagainst, shauheen, jpienaar, rriddle, mehdi_amini.
Herald added a reviewer: nicolasvasilache.
Herald added a project: LLVM.
mravishankar added a child revision: D79765: [mlir][Linalg] Add folders and canonicalizers for linalg.reshape/linalg.tensor_reshape operations.

This is only valid if the source tensor (result tensor) is statically
shaped with all unit extents when the reshape is collapsing
(expanding) dimensions.
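
As an illustration (a minimal sketch based on the tests below; the
invalid case uses a hypothetical tensor<2x1xf32> source that is not
part of this patch):

  // Accepted: every source extent is 1, so collapsing to rank zero is valid.
  %0 = linalg.tensor_reshape %arg0 [] : tensor<1x1xf32> into tensor<f32>
  // Rejected by the new verifier check: the extent 2 is non-unit.
  %1 = linalg.tensor_reshape %arg1 [] : tensor<2x1xf32> into tensor<f32>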

Depends On D79763 <https://reviews.llvm.org/D79763>


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D79764

Files:
  mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
  mlir/test/Dialect/Linalg/llvm.mlir
  mlir/test/Dialect/Linalg/roundtrip.mlir


Index: mlir/test/Dialect/Linalg/roundtrip.mlir
===================================================================
--- mlir/test/Dialect/Linalg/roundtrip.mlir
+++ mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -633,3 +633,26 @@
 //       CHECK:   linalg.batch_matmul
 //       CHECK:   linalg.batch_matmul
 
+// -----
+
+func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>) -> (tensor<f32>, tensor<1x1xf32>)
+{
+  %0 = linalg.tensor_reshape %arg0 [] : tensor<1x1xf32> into tensor<f32>
+  %1 = linalg.tensor_reshape %0 [] : tensor<f32> into tensor<1x1xf32>
+  return %0, %1 : tensor<f32>, tensor<1x1xf32>
+}
+// CHECK-LABEL: func @tensor_reshape_zero_dim
+//       CHECK:   linalg.tensor_reshape %{{.*}} [] : tensor<1x1xf32> into tensor<f32>
+//       CHECK:   linalg.tensor_reshape %{{.*}} [] : tensor<f32> into tensor<1x1xf32>
+
+// -----
+
+func @memref_reshape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>) -> (memref<f32>, memref<1x1xf32>)
+{
+  %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
+  %1 = linalg.reshape %0 [] : memref<f32> into memref<1x1xf32>
+  return %0, %1 : memref<f32>, memref<1x1xf32>
+}
+// CHECK-LABEL: func @memref_reshape_zero_dim
+//       CHECK:   linalg.reshape %{{.*}} [] : memref<1x1xf32> into memref<f32>
+//       CHECK:   linalg.reshape %{{.*}} [] : memref<f32> into memref<1x1xf32>
Index: mlir/test/Dialect/Linalg/llvm.mlir
===================================================================
--- mlir/test/Dialect/Linalg/llvm.mlir
+++ mlir/test/Dialect/Linalg/llvm.mlir
@@ -273,3 +273,32 @@
 //       CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
 //       CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
 //       CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+
+func @reshape_zero_dim(%arg0 : memref<1x1xf32>) {
+  %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
+  %1 = linalg.reshape %0 [] : memref<f32> into memref<1x1xf32>
+  return
+}
+// CHECK-LABEL: func @reshape_zero_dim
+//       CHECK:   llvm.mlir.undef : !llvm<"{ float*, float*, i64 }">
+//       CHECK:   llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
+//       CHECK:   llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
+//       CHECK:   llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
+//       CHECK:   llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
Index: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -553,11 +553,21 @@
     std::swap(expandedRank, collapsedRank);
     std::swap(expandedType, collapsedType);
   }
-  if (expandedRank == 0 || collapsedRank == 0)
+  if (expandedRank == 0)
     return op.emitOpError("expected non-zero memref ranks");
   if (expandedRank == collapsedRank)
     return op.emitOpError("expected to collapse or expand dims");
 
+  if (collapsedRank == 0) {
+    // If the collapsed rank is 0, the expanded type must be statically shaped
+    // with all unit extents.
+    if (llvm::any_of(expandedType.getShape(),
+                     [](int64_t dim) -> bool { return dim != 1; }))
+      return op.emitOpError(
+          "invalid to reshape tensor/memref with non-unit extent dimensions to "
+          "zero-rank tensor/memref");
+    return success();
+  }
   if (collapsedRank != op.reassociation().size())
     return op.emitOpError("expected rank of the collapsed type(")
            << collapsedRank << ") to be the number of reassociation maps("

