[Mlir-commits] [mlir] ff5de8a - [linalg][fusion] Disallow fusion when it would create an invalid expand_shape

Benjamin Kramer llvmlistbot at llvm.org
Tue Jan 18 14:44:40 PST 2022


Author: Benjamin Kramer
Date: 2022-01-18T23:44:14+01:00
New Revision: ff5de8a9e0e5cb7f82b945486b784407f6aab8fe

URL: https://github.com/llvm/llvm-project/commit/ff5de8a9e0e5cb7f82b945486b784407f6aab8fe
DIFF: https://github.com/llvm/llvm-project/commit/ff5de8a9e0e5cb7f82b945486b784407f6aab8fe.diff

LOG: [linalg][fusion] Disallow fusion when it would create an invalid expand_shape

The input type of a linalg.generic can be less dynamic than its output
type. If this is the case, moving a reshape across the generic op would
create invalid IR, as expand_shape cannot expand arbitrary dynamic
dimensions.

Check that the reshape is actually valid before creating the
expand_shape. This exposes the existing verification logic in reshape
utils and removes the incomplete custom implementation in fusion.

Differential Revision: https://reviews.llvm.org/D116600

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
    mlir/test/Dialect/Linalg/reshape_fusion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 1c42c4b8542e..b2d4cf1e4bff 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -166,47 +166,19 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
 /// 2) if a dimension in the collaped type is dynamic, one and only one of the
 ///    corresponding dimensions in the expanded type should be dynamic. This
 ///    rule is only needed with reshape operations that are expanding.
+LogicalResult reshapeLikeShapesAreCompatible(
+    function_ref<LogicalResult(const Twine &)> emitError,
+    ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
+    ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
+
 template <typename OpTy>
 static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
                                              ShapedType expandedType,
                                              bool isExpandingReshape) {
-  ArrayRef<int64_t> collapsedShape = collapsedType.getShape();
-  ArrayRef<int64_t> expandedShape = expandedType.getShape();
-  unsigned expandedDimStart = 0;
-  for (auto map : llvm::enumerate(op.getReassociationMaps())) {
-    Optional<int64_t> dynamicShape;
-    int64_t linearizedStaticShape = 1;
-    for (auto dim : llvm::enumerate(expandedShape.slice(
-             expandedDimStart, map.value().getNumResults()))) {
-      if (ShapedType::isDynamic(dim.value())) {
-        if (isExpandingReshape && dynamicShape) {
-          return op->emitOpError("invalid to have a single dimension (")
-                 << map.index() << ") expanded into multiple dynamic dims ("
-                 << expandedDimStart + dynamicShape.getValue() << ","
-                 << expandedDimStart + dim.index() << ")";
-        }
-        dynamicShape = dim.index();
-      } else {
-        linearizedStaticShape *= dim.value();
-      }
-    }
-    if (dynamicShape) {
-      if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
-        return op->emitOpError("expected dimension ")
-               << map.index()
-               << " of collapsed type to be dynamic since one or more of the "
-                  "corresponding dimensions in the expanded type is dynamic";
-      }
-    } else {
-      if (collapsedShape[map.index()] != linearizedStaticShape) {
-        return op->emitOpError("expected dimension ")
-               << map.index() << " of collapsed type to be static value of "
-               << linearizedStaticShape << " ";
-      }
-    }
-    expandedDimStart += map.value().getNumResults();
-  }
-  return success();
+  return reshapeLikeShapesAreCompatible(
+      [&](const Twine &msg) { return op->emitOpError(msg); },
+      collapsedType.getShape(), expandedType.getShape(),
+      op.getReassociationIndices(), isExpandingReshape);
 }
 
 /// Pattern to collapse producer/consumer reshape ops that are both collapsing

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index ce636f3d84d3..619908ca5a38 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -608,31 +608,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
 LogicalResult isGenericOpExpandable(GenericOp genericOp,
                                     const ExpansionInfo &expansionInfo,
                                     PatternRewriter &rewriter) {
-  // Current reshape only supports expansion of a dynamic dim when only one of
-  // the expanded dims are dynamic.
-  for (const auto &originalShape :
-       llvm::enumerate(expansionInfo.getOriginalShape()))
-    if (ShapedType::isDynamic(originalShape.value())) {
-      // All but one of the expanded dims must be static.
-      bool foundDynamicExpandedDim = false;
-      for (auto expandedShape :
-           expansionInfo.getExpandedShapeOfDim(originalShape.index())) {
-        if (ShapedType::isDynamic(expandedShape)) {
-          if (foundDynamicExpandedDim) {
-            return rewriter.notifyMatchFailure(
-                genericOp,
-                "cannot expanded dynamic dims into multiple dynamic dims");
-          }
-          foundDynamicExpandedDim = true;
-        }
-      }
-      if (!foundDynamicExpandedDim) {
-        return rewriter.notifyMatchFailure(
-            genericOp, "dynamic dim expansion needs at least one dynamic dim "
-                       "in result shape");
-      }
-    }
-
   if (!genericOp.hasIndexSemantics())
     return success();
   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
@@ -793,13 +768,21 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
     }
     if (genericOp.isInputTensor(opOperand)) {
       AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+      auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
       RankedTensorType expandedOperandType =
-          getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
-                          indexingMap, expansionInfo);
+          getExpandedType(opOperandType, indexingMap, expansionInfo);
       if (expandedOperandType != opOperand->get().getType()) {
         // Reshape the operand to get the right type.
         SmallVector<ReassociationIndices> reassociation =
             getReassociationForExpansion(indexingMap, expansionInfo);
+        if (failed(reshapeLikeShapesAreCompatible(
+                [&](const Twine &msg) {
+                  return rewriter.notifyMatchFailure(genericOp, msg);
+                },
+                opOperandType.getShape(), expandedOperandType.getShape(),
+                reassociation,
+                /*isExpandingReshape=*/true)))
+          return llvm::None;
         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
             genericOp.getLoc(), expandedOperandType, opOperand->get(),
             reassociation));
@@ -813,12 +796,20 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   SmallVector<Value> outputs;
   for (OpOperand *opOperand : genericOp.getOutputOperands()) {
     AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+    auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
     RankedTensorType expandedOutputType =
-        getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
-                        indexingMap, expansionInfo);
+        getExpandedType(opOperandType, indexingMap, expansionInfo);
     if (expandedOutputType != opOperand->get().getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(indexingMap, expansionInfo);
+      if (failed(reshapeLikeShapesAreCompatible(
+              [&](const Twine &msg) {
+                return rewriter.notifyMatchFailure(genericOp, msg);
+              },
+              opOperandType.getShape(), expandedOutputType.getShape(),
+              reassociation,
+              /*isExpandingReshape=*/true)))
+        return llvm::None;
       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
           genericOp.getLoc(), expandedOutputType, opOperand->get(),
           reassociation));

diff  --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index b5136273d635..0048abee4194 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -276,3 +276,45 @@ bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
   }
   return true;
 }
+
+LogicalResult mlir::reshapeLikeShapesAreCompatible(
+    function_ref<LogicalResult(const Twine &)> emitError,
+    ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
+    ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
+  unsigned expandedDimStart = 0;
+  for (const auto &map : llvm::enumerate(reassociationMaps)) {
+    Optional<int64_t> dynamicShape;
+    int64_t linearizedStaticShape = 1;
+    for (const auto &dim : llvm::enumerate(
+             expandedShape.slice(expandedDimStart, map.value().size()))) {
+      if (ShapedType::isDynamic(dim.value())) {
+        if (isExpandingReshape && dynamicShape) {
+          return emitError("invalid to have a single dimension (" +
+                           Twine(map.index()) +
+                           ") expanded into multiple dynamic dims (" +
+                           Twine(expandedDimStart + dynamicShape.getValue()) +
+                           "," + Twine(expandedDimStart + dim.index()) + ")");
+        }
+        dynamicShape = dim.index();
+      } else {
+        linearizedStaticShape *= dim.value();
+      }
+    }
+    if (dynamicShape) {
+      if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
+        return emitError(
+            "expected dimension " + Twine(map.index()) +
+            " of collapsed type to be dynamic since one or more of the "
+            "corresponding dimensions in the expanded type is dynamic");
+      }
+    } else {
+      if (collapsedShape[map.index()] != linearizedStaticShape) {
+        return emitError("expected dimension " + Twine(map.index()) +
+                         " of collapsed type to be static value of " +
+                         Twine(linearizedStaticShape));
+      }
+    }
+    expandedDimStart += map.value().size();
+  }
+  return success();
+}

diff  --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 9582a4bbafc4..508726a4e60e 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -530,3 +530,30 @@ func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       ins(%[[RESHAPE]] : tensor<?xf32>)
 //      CHECK:   return %[[GENERIC]]
+
+// -----
+
+func @no_fuse_mismatched_dynamism(%arg0: tensor<1x1xi64>, %arg1: tensor<?xi64>) -> tensor<1xi64> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x1xi64> into tensor<1xi64>
+  %1 = linalg.init_tensor [1] : tensor<1xi64>
+  %2 = linalg.generic
+    {indexing_maps = [affine_map<(d0) -> (d0)>,
+                      affine_map<(d0) -> (d0)>,
+                      affine_map<(d0) -> (d0)>],
+     iterator_types = ["parallel"]}
+    ins(%0, %arg1 : tensor<1xi64>, tensor<?xi64>)
+    outs(%1 : tensor<1xi64>) {
+  ^bb0(%arg4: i64, %arg5: i64, %arg6: i64):  // no predecessors
+    %3 = arith.addi %arg4, %arg5 : i64
+    linalg.yield %3 : i64
+  } -> tensor<1xi64>
+  return %2 : tensor<1xi64>
+}
+
+//      CHECK: func @no_fuse_mismatched_dynamism
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x1xi64>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<?xi64>
+//      CHECK:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
+//      CHECK:   %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[RESHAPE]], %[[ARG1]] : tensor<1xi64>, tensor<?xi64>)
+//      CHECK:   return %[[GENERIC]]


        


More information about the Mlir-commits mailing list