[Mlir-commits] [mlir] a43f7d6 - [mlir][tensor] Extend reshape utils.

Stephan Herhut llvmlistbot at llvm.org
Fri Feb 18 00:58:09 PST 2022


Author: Stephan Herhut
Date: 2022-02-18T09:57:39+01:00
New Revision: a43f7d6d76984ddae4a5e5e0bebf82ee2edebabb

URL: https://github.com/llvm/llvm-project/commit/a43f7d6d76984ddae4a5e5e0bebf82ee2edebabb
DIFF: https://github.com/llvm/llvm-project/commit/a43f7d6d76984ddae4a5e5e0bebf82ee2edebabb.diff

LOG: [mlir][tensor] Extend reshape utils.

This change alters the handling of trailing dimensions with unknown
extent. Users of the getReassociationIndicesForReshape helper should
see benefits when transforming reshape-like operations into
expand/collapse pairs if the higher-rank type has trailing unknown
dimensions.

The motivating example is a reshape from tensor<16x1x?xi32> to
tensor<16xi32> that can be modeled as collapsing the three dimensions.

Differential Revision: https://reviews.llvm.org/D119730

Added: 
    

Modified: 
    mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index fd509621015d2..17f449e489e38 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -35,13 +35,12 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
   int64_t prodOfCollapsedDims = 1;
   while (sourceDim < sourceShape.size()) {
     unsigned targetDim = reassociationMap.size();
+    // If we have mapped all the target dimensions, stop and handle the
+    // remaining tail of size-1 dimensions explicitly.
+    if (targetDim == targetType.getRank())
+      break;
 
-    // If all the dimensions of the targetShape are exhausted, then the
-    // remaining dims in the source shape must be all 1s. So for such cases, set
-    // 1 as the target shape. The actual reassociation indices will be handled
-    // later.
-    int64_t currTargetShape =
-        (targetDim < targetType.getRank() ? targetShape[targetDim] : 1);
+    int64_t currTargetShape = targetShape[targetDim];
     while (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
            prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape &&
            sourceDim < sourceShape.size()) {
@@ -69,25 +68,23 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
       return llvm::None;
 
     currIndices.push_back(sourceDim++);
-    // If the reassociation is empty but the currIndices is not, this by
-    // definition is folding unit-dimensions with the result being scalar type.
-    // So only append the `currIndices` if reassociation map is not empty.
-    if (targetDim == targetShape.size()) {
-      while (sourceDim < sourceShape.size())
-        currIndices.push_back(sourceDim++);
-      if (!reassociationMap.empty() && !currIndices.empty())
-        reassociationMap.back().append(currIndices.begin(), currIndices.end());
-      // Break out of the loops. We should be done here.
-      break;
-    }
     reassociationMap.emplace_back(ReassociationIndices{});
     std::swap(reassociationMap.back(), currIndices);
     prodOfCollapsedDims = 1;
   }
-  // All the dimensions in the two shapes must have been processed.
-  if (reassociationMap.size() != targetShape.size() ||
-      sourceDim != sourceShape.size())
+  // All the dimensions in the target must have been processed.
+  if (reassociationMap.size() != targetShape.size())
     return llvm::None;
+  // Process any remaining entries in the source shape. They all need to be
+  // 1 or dynamic.
+  for (; sourceDim < sourceShape.size(); sourceDim++) {
+    if (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
+        sourceShape[sourceDim] != 1)
+      return llvm::None;
+    // The map is empty when the target type is a scalar.
+    if (!reassociationMap.empty())
+      reassociationMap.back().push_back(sourceDim);
+  }
   return reassociationMap;
 }
 

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 4e4bbb8a12672..ce3db8d6039c2 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -879,16 +879,17 @@ func @fold_reshape_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
 
 // -----
 
-func @no_fold_reshapes(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32> {
+func @fold_reshapes_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32> {
   %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
       : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]]
       : tensor<?x?x1x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: func @no_fold_reshapes
-//       CHECK:   tensor.expand_shape
-//       CHECK:   tensor.collapse_shape
+// CHECK-LABEL: func @fold_reshapes_unit_dims_in_middle
+//  CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>
+//       CHECK: tensor.collapse_shape %[[ARG]] {{\[}}[0], [1, 2]]
+//  CHECK-SAME:   tensor<?x?x?xf32> into tensor<?x?xf32>
 
 // -----
 


        


More information about the Mlir-commits mailing list