[Mlir-commits] [mlir] 4317a3d - [mlir][Linalg] Disable fusion of reshape with elementwise ops for purely dynamic cases.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 6 10:32:35 PST 2022


Author: MaheshRavishankar
Date: 2022-01-06T10:32:24-08:00
New Revision: 4317a3dfad52ee2830d686c6f3a9ef8011f7e6ad

URL: https://github.com/llvm/llvm-project/commit/4317a3dfad52ee2830d686c6f3a9ef8011f7e6ad
DIFF: https://github.com/llvm/llvm-project/commit/4317a3dfad52ee2830d686c6f3a9ef8011f7e6ad.diff

LOG: [mlir][Linalg] Disable fusion of reshape with elementwise ops for purely dynamic cases.

Fusing a `tensor.collapse_shape` op with a consumer elementwise
`linalg.generic` operation results in the creation of `tensor.expand_shape`
ops. In purely dynamic cases this can result in a dynamic dimension
being expanded into more than one dynamic dimension, which is disallowed
by the semantics of the `tensor.expand_shape` operation. (The
transformation itself is correct; the issue is a gap in the specification
of `tensor.expand_shape`.) So disallow fusions that would result in such
a pattern.

Differential Revision: https://reviews.llvm.org/D116703

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/test/Dialect/Linalg/reshape_fusion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 6fd3927c80cac..8f7c331597bce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -524,6 +524,7 @@ class ExpansionInfo {
   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                         ArrayRef<AffineMap> reassociationMaps,
                         ArrayRef<int64_t> expandedShape,
+                        ArrayRef<int64_t> collapsedShape,
                         PatternRewriter &rewriter);
   unsigned getOrigOpNumDims() const { return reassociation.size(); }
   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
@@ -533,6 +534,7 @@ class ExpansionInfo {
   ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
     return expandedShapeMap[i];
   }
+  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
 
 private:
   /// Reassociation from the dimensions in the original operation to the
@@ -541,6 +543,8 @@ class ExpansionInfo {
   /// Mapping from extent of loops in the original operation, to the extent of
   /// loops in the expanded operation.
   SmallVector<SmallVector<int64_t>> expandedShapeMap;
+  /// Extents of the loops in the original operation.
+  SmallVector<int64_t> originalLoopExtent;
   unsigned expandedOpNumDims;
 };
 } // namespace
@@ -549,6 +553,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                      OpOperand *fusableOpOperand,
                                      ArrayRef<AffineMap> reassociationMaps,
                                      ArrayRef<int64_t> expandedShape,
+                                     ArrayRef<int64_t> collapsedShape,
                                      PatternRewriter &rewriter) {
   if (reassociationMaps.empty())
     return failure();
@@ -558,6 +563,8 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
       linalgOp.getStaticLoopRanges();
   if (!originalLoopRange)
     return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
+  originalLoopExtent.assign(originalLoopRange->begin(),
+                            originalLoopRange->end());
 
   reassociation.clear();
   expandedShapeMap.clear();
@@ -576,7 +583,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
   // The remaining dimensions remain the same.
   for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
     if (expandedShapeMap[i].empty())
-      expandedShapeMap[i] = {(*originalLoopRange)[i]};
+      expandedShapeMap[i] = {originalLoopExtent[i]};
 
   // Compute reassociation map from the original op to the expanded op.
   unsigned sum = 0;
@@ -601,6 +608,30 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
 LogicalResult isGenericOpExpandable(GenericOp genericOp,
                                     const ExpansionInfo &expansionInfo,
                                     PatternRewriter &rewriter) {
+  // The current reshape only supports expanding a dynamic dim when exactly
+  // one of the expanded dims is dynamic.
+  for (auto originalShape : llvm::enumerate(expansionInfo.getOriginalShape()))
+    if (ShapedType::isDynamic(originalShape.value())) {
+      // All but one of the expanded dims must be static.
+      bool foundDynamicExpandedDim = false;
+      for (auto expandedShape :
+           expansionInfo.getExpandedShapeOfDim(originalShape.index())) {
+        if (ShapedType::isDynamic(expandedShape)) {
+          if (foundDynamicExpandedDim) {
+            return rewriter.notifyMatchFailure(
+                genericOp,
+                "cannot expanded dynamic dims into multiple dynamic dims");
+          }
+          foundDynamicExpandedDim = true;
+        }
+      }
+      if (!foundDynamicExpandedDim) {
+        return rewriter.notifyMatchFailure(
+            genericOp, "dynamic dim expansion needs at least one dynamic dim "
+                       "in result shape");
+      }
+    }
+
   if (!genericOp.hasIndexSemantics())
     return success();
   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
@@ -731,13 +762,16 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   RankedTensorType expandedType = isExpanding
                                       ? expandingReshapeOp.getResultType()
                                       : collapsingReshapeOp.getSrcType();
+  RankedTensorType collapsedType = isExpanding
+                                       ? expandingReshapeOp.getSrcType()
+                                       : collapsingReshapeOp.getResultType();
 
   ExpansionInfo expansionInfo;
   if (failed(expansionInfo.compute(
           genericOp, fusableOpOperand,
           isExpanding ? expandingReshapeOp.getReassociationMaps()
                       : collapsingReshapeOp.getReassociationMaps(),
-          expandedType.getShape(), rewriter)))
+          expandedType.getShape(), collapsedType.getShape(), rewriter)))
     return llvm::None;
 
   if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))

diff  --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 324aa2809ce37..9582a4bbafc43 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -507,3 +507,26 @@ func @unit_dim_reshape_expansion_full
 //    FOLDUNITDIM-SAME:     ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
 //    FOLDUNITDIM-SAME:     outs(%{{.+}} : tensor<1x?x1x2x1x4xf32>)
 
+// -----
+
+func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
+  %1 = tensor.dim %0, %c0 : tensor<?xf32>
+  %2 = linalg.init_tensor [%1] : tensor<?xf32>
+  %3 = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]}
+    ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {
+      ^bb0(%arg1 : f32, %arg2: f32):
+        %4 = arith.addf %arg1, %arg1 : f32
+        linalg.yield %4 : f32
+    } -> tensor<?xf32>
+  return %3 : tensor<?xf32>
+}
+//      CHECK: func @no_fuse_dynamic_dims
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
+//      CHECK:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
+//      CHECK:   %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[RESHAPE]] : tensor<?xf32>)
+//      CHECK:   return %[[GENERIC]]


        


More information about the Mlir-commits mailing list