[Mlir-commits] [mlir] d2a9569 - [mlir][Linalg] Allow reshapes to collapse to a zero-rank tensor.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 12 23:03:59 PDT 2020


Author: MaheshRavishankar
Date: 2020-05-12T23:03:25-07:00
New Revision: d2a9569850166083638a8fb88f7dae1c1b62a926

URL: https://github.com/llvm/llvm-project/commit/d2a9569850166083638a8fb88f7dae1c1b62a926
DIFF: https://github.com/llvm/llvm-project/commit/d2a9569850166083638a8fb88f7dae1c1b62a926.diff

LOG: [mlir][Linalg] Allow reshapes to collapse to a zero-rank tensor.

This is only valid if the source tensor (result tensor) is statically
shaped with all unit-extent dimensions when the reshape is collapsing
(expanding) dimensions.

Differential Revision: https://reviews.llvm.org/D79764

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/llvm.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index af1bc121bf3b..874bda002a59 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -111,6 +111,13 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape">,
     rank to obtain the memref with the smaller rank. In the case of a dimension
     expansion, the reassociation maps can be interpreted as inverse maps.
 
+    The result memref type of a reshape when dimensions are collapsed
+    (operand memref type when dimensions are expanded) can be
+    zero-ranked if the operand memref type (or the result memref type
+    when dimensions are expanded) is statically shaped with all
+    dimensions being unit extent. In such cases the reassociation map
+    is empty.
+
     Examples:
 
     ```mlir
@@ -152,6 +159,13 @@ def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
     rank to obtain the tensor with the smaller rank. In the case of a dimension
     expansion, the reassociation maps can be interpreted as inverse maps.
 
+    The result tensor type of a reshape when dimensions are collapsed
+    (operand tensor type when dimensions are expanded) can be
+    zero-ranked if the operand tensor type (or the result tensor type
+    when dimensions are expanded) is statically shaped with all
+    dimensions being unit extent. In such cases the reassociation map
+    is empty.
+
     Examples:
 
     ```mlir

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index fc2353e4087e..4b1b8cde639e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -438,11 +438,21 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T &expandedType,
     std::swap(expandedRank, collapsedRank);
     std::swap(expandedType, collapsedType);
   }
-  if (expandedRank == 0 || collapsedRank == 0)
+  if (expandedRank == 0)
     return op.emitOpError("expected non-zero memref ranks");
   if (expandedRank == collapsedRank)
     return op.emitOpError("expected to collapse or expand dims");
 
+  if (collapsedRank == 0) {
+    // If collapsed rank is 0, then expanded type must be static shaped and of
+    // sizes 1.
+    if (llvm::any_of(expandedType.getShape(),
+                     [](int64_t dim) -> bool { return dim != 1; }))
+      return op.emitOpError(
+          "invalid to reshape tensor/memref with non-unit extent dimensions to "
+          "zero-rank tensor/memref");
+    return success();
+  }
   if (collapsedRank != op.reassociation().size())
     return op.emitOpError("expected rank of the collapsed type(")
            << collapsedRank << ") to be the number of reassociation maps("

diff  --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index 290e1a2fe4d7..84b3bcb66940 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -273,3 +273,32 @@ func @reshape_static(%arg0: memref<3x4x5xf32>) {
 //       CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
 //       CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
 //       CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+
+func @reshape_zero_dim(%arg0 : memref<1x1xf32>) {
+  %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
+  %1 = linalg.reshape %0 [] : memref<f32> into memref<1x1xf32>
+  return
+}
+// CHECK-LABEL: func @reshape_zero_dim
+//       CHECK:   llvm.mlir.undef : !llvm<"{ float*, float*, i64 }">
+//       CHECK:   llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
+//       CHECK:   llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
+//       CHECK:   llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
+//       CHECK:   llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+//       CHECK:   llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 9b85b2874658..5237db79c42e 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -630,3 +630,26 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x
 //       CHECK:   linalg.batch_matmul
 //       CHECK:   linalg.batch_matmul
 
+// -----
+
+func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>) -> (tensor<f32>, tensor<1x1xf32>)
+{
+  %0 = linalg.tensor_reshape %arg0 [] : tensor<1x1xf32> into tensor<f32>
+  %1 = linalg.tensor_reshape %0 [] : tensor<f32> into tensor<1x1xf32>
+  return %0, %1 : tensor<f32>, tensor<1x1xf32>
+}
+// CHECK-LABEL: func @tensor_reshape_zero_dim
+//       CHECK:   linalg.tensor_reshape %{{.*}} [] : tensor<1x1xf32> into tensor<f32>
+//       CHECK:   linalg.tensor_reshape %{{.*}} [] : tensor<f32> into tensor<1x1xf32>
+
+// -----
+
+func @memref_reshape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>) -> (memref<f32>, memref<1x1xf32>)
+{
+  %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
+  %1 = linalg.reshape %0 [] : memref<f32> into memref<1x1xf32>
+  return %0, %1 : memref<f32>, memref<1x1xf32>
+}
+// CHECK-LABEL: func @memref_reshape_zero_dim
+//       CHECK:   linalg.reshape %{{.*}} [] : memref<1x1xf32> into memref<f32>
+//       CHECK:   linalg.reshape %{{.*}} [] : memref<f32> into memref<1x1xf32>


        


More information about the Mlir-commits mailing list