[Mlir-commits] [mlir] cdf7b66 - [mlir][Linalg] Fix incorrect logic in deciding when to fuse reshapes by linearization.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 2 11:17:02 PDT 2021


Author: MaheshRavishankar
Date: 2021-07-02T11:16:21-07:00
New Revision: cdf7b661c24d037461492544996925dd5257911b

URL: https://github.com/llvm/llvm-project/commit/cdf7b661c24d037461492544996925dd5257911b
DIFF: https://github.com/llvm/llvm-project/commit/cdf7b661c24d037461492544996925dd5257911b.diff

LOG: [mlir][Linalg] Fix incorrect logic in deciding when to fuse reshapes by linearization.

Fusion by linearization should not happen when
- The reshape is expanding and it is a consumer
- The reshape is collapsing and is a producer.

The bug introduced in this logic by some recent refactoring resulted
in a crash.
To enforce this (negative) use case, add a test that reproduces the
error and verifies the fix.

Differential Revision: https://reviews.llvm.org/D104970

Added: 
    mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index c638294b1210..fc669b270576 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -75,6 +75,12 @@ def LinalgFoldReshapeOpsByLinearization :
   let summary = "Fold TensorReshapeOps with generic/indexed generic ops by "
                 "linearization";
   let constructor = "mlir::createFoldReshapeOpsByLinearizationPass()";
+  let options = [
+    Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes",
+           "bool", /*default=*/"false",
+           "Allow fusing linalg.tensor_reshape ops that performs unit "
+           "dimension collapsing">
+  ];
   let dependentDialects = ["AffineDialect", "memref::MemRefDialect"];
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index f577f00c8b0f..20f13bd45107 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -320,27 +320,27 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
 /// %0 = op ... : tensor<?x?x4x5xf32>
 /// with output index_map
 ///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
+template <typename TensorReshapeOp>
 static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
-                                        ArrayRef<int64_t> sourceShape,
-                                        ArrayRef<AffineMap> reassociationMaps) {
+                                        TensorReshapeOp reshapeOp) {
+  constexpr bool isExpanding =
+      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value;
+  ArrayRef<int64_t> sourceShape =
+      (isExpanding ? reshapeOp.getResultType().getShape()
+                   : reshapeOp.getSrcType().getShape());
   SmallVector<AffineExpr> resultExprs;
-  resultExprs.reserve(reassociationMaps.size());
   ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
   MLIRContext *context = sourceMap.getContext();
 
   // Compute the result exprs based on the reassociation maps.
-  for (AffineMap map : reassociationMaps) {
-    ArrayRef<AffineExpr> collapsedDims = map.getResults();
+  for (auto &indices : reshapeOp.getReassociationIndices()) {
     // Assume that they are in-order and contiguous (already checked in
     // verifier).
-    assert(!collapsedDims.empty());
-    unsigned startDim =
-        collapsedDims.front().cast<AffineDimExpr>().getPosition();
+    assert(!indices.empty());
     SmallVector<int64_t> sizes;
     SmallVector<AffineExpr> dimExprs;
-    for (auto en :
-         llvm::zip(sourceShape.slice(startDim, collapsedDims.size()),
-                   sourceExprs.slice(startDim, collapsedDims.size()))) {
+    for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
+                             sourceExprs.slice(indices[0], indices.size()))) {
       if (std::get<0>(en) == 1)
         continue;
       sizes.push_back(std::get<0>(en));
@@ -359,7 +359,7 @@ static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
 // divs in the indexing maps of the fused op which would make it non-invertible.
 static bool isTensorReshapeOpFoldableByLinearization(
     TensorExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
-  if (!asProducer && expandOp.getResultType().hasStaticShape())
+  if (!asProducer)
     return false;
   return useIndexMap.isPermutation();
 }
@@ -368,23 +368,26 @@ static bool isTensorReshapeOpFoldableByLinearization(
 // consumer).
 static bool isTensorReshapeOpFoldableByLinearization(
     TensorCollapseShapeOp collapseOp, AffineMap useIndexMap, bool asProducer) {
-  if (asProducer && collapseOp.getSrcType().hasStaticShape())
+  if (asProducer)
     return false;
   return useIndexMap.isPermutation();
 }
 
 /// Check if the reshape operation is only expansion into/collapsing of
 /// unit-dimension.
-static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
-                                   ArrayRef<AffineMap> reassociation) {
-  for (auto &map : reassociation) {
+template <typename TensorReshapeOp>
+static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
+  constexpr bool isExpanding =
+      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value;
+  ArrayRef<int64_t> expandedShape =
+      (isExpanding ? reshapeOp.getResultType().getShape()
+                   : reshapeOp.getSrcType().getShape());
+  for (auto &indices : reshapeOp.getReassociationIndices()) {
     unsigned numUnitDims = 0;
-    for (AffineExpr expr : map.getResults()) {
-      unsigned position = expr.cast<AffineDimExpr>().getPosition();
+    for (int64_t position : indices)
       if (expandedShape[position] == 1)
         numUnitDims++;
-    }
-    if (numUnitDims != map.getNumResults() - 1)
+    if (numUnitDims != indices.size() - 1)
       return false;
   }
   return true;
@@ -818,14 +821,10 @@ struct FoldProducerReshapeOpByLinearization
       if (!reshapeOp)
         continue;
 
-      RankedTensorType returnType = reshapeOp.getResultType();
-
       if (!isTensorReshapeOpFoldableByLinearization(
               reshapeOp, genericOp.getTiedIndexingMap(en.value()),
               /*asProducer =*/true) ||
-          (foldUnitDimReshapesOnly &&
-           !isUnitDimExpansionOnly(returnType.getShape(),
-                                   reshapeOp.getReassociationMaps())))
+          (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
         continue;
 
       // Compute the fused operands list,
@@ -842,8 +841,10 @@ struct FoldProducerReshapeOpByLinearization
       auto invMap = inversePermutation(fusedIndexMaps[en.index()]);
 
       // Compute the indexing map to use for the result of the producer.
-      AffineMap modifiedMap = linearizeCollapsedDims(
-          invMap, returnType.getShape(), reshapeOp.getReassociationMaps());
+      AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
+      // The modified map cannot have symbols.
+      if (modifiedMap.getNumSymbols())
+        return failure();
       for (AffineExpr expr : modifiedMap.getResults()) {
         if (!expr.isPureAffine())
           return failure();
@@ -1081,9 +1082,7 @@ struct FoldConsumerReshapeOpByLinearization
             reshapeOp,
             producer.getTiedIndexingMap(producer.getOutputOperand(0)),
             /*asProducer =*/false) ||
-        (foldUnitDimReshapesOnly &&
-         !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
-                                 reshapeOp.getReassociationMaps())))
+        (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
       return failure();
     // The indexing_maps for the operands of the fused operation are same as
     // those for the operands of the producer.
@@ -1093,9 +1092,7 @@ struct FoldConsumerReshapeOpByLinearization
         producer.getTiedIndexingMap(producer.getOutputOperand(0)));
 
     // Compute the indexing map to use for the operand of the producer.
-    AffineMap modifiedMap =
-        linearizeCollapsedDims(invMap, reshapeOp.getSrcType().getShape(),
-                               reshapeOp.getReassociationMaps());
+    AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
     for (AffineExpr expr : modifiedMap.getResults()) {
       if (!expr.isPureAffine()) {
         return rewriter.notifyMatchFailure(
@@ -1144,8 +1141,7 @@ struct FoldReshapeWithGenericOpByExpansion
     if (!producer || producer.getNumOutputs() != 1 ||
         !isFusableWithReshapeByDimExpansion(producer,
                                             producer.getOutputOperand(0)) ||
-        isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
-                               reshapeOp.getReassociationMaps()))
+        isUnitDimExpansionOnly(reshapeOp))
       return failure();
     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
         producer, reshapeOp, producer.getOutputOperand(0), rewriter);
@@ -1248,12 +1244,10 @@ bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
                                       const OpOperand &consumer) {
   auto expandShapeOp = producer.getDefiningOp<linalg::TensorExpandShapeOp>();
   if (expandShapeOp)
-    return !isUnitDimExpansionOnly(expandShapeOp.getSrcType().getShape(),
-                                   expandShapeOp.getReassociationMaps());
+    return !isUnitDimExpansionOnly(expandShapeOp);
   auto collapseShapeOp =
       producer.getDefiningOp<linalg::TensorCollapseShapeOp>();
-  return !isUnitDimExpansionOnly(collapseShapeOp.getSrcType().getShape(),
-                                 collapseShapeOp.getReassociationMaps());
+  return !isUnitDimExpansionOnly(collapseShapeOp);
 }
 
 namespace {
@@ -1312,6 +1306,9 @@ struct FoldReshapeOpsByLinearizationPass
     Operation *op = getOperation();
     RewritePatternSet patterns(op->getContext());
     populateFoldReshapeOpsByLinearizationPatterns(patterns);
+    if (allowFoldingUnitDimReshapes) {
+      populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
+    }
     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
   }
 };

diff  --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir
new file mode 100644
index 000000000000..a4a27b5e747b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt -linalg-fold-reshape-ops-by-linearization=allow-folding-unit-dim-reshapes -split-input-file %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func @do_not_fold1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?x1xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %3 = linalg.generic {
+      indexing_maps = [#map, #map, #map],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%2 : tensor<?x?xf32>) {
+      ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+        %4 = addf %arg2, %arg3 : f32
+        linalg.yield %4 : f32
+      } -> tensor<?x?xf32>
+  %4 = linalg.tensor_expand_shape %3 [[0], [1, 2]] : tensor<?x?xf32> into tensor<?x?x1xf32>
+  return %4 : tensor<?x?x1xf32>
+}
+// CHECK-LABEL: func @do_not_fold1
+//       CHECK: %[[VAL:.+]] = linalg.generic
+//       CHECK: linalg.tensor_expand_shape %[[VAL]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func @do_not_fold2(%arg0 : tensor<?x?x1xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2]] : tensor<?x?x1xf32> into tensor<?x?xf32>
+  %1 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
+  %2 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %3 = linalg.init_tensor [%1, %2] : tensor<?x?xf32>
+  %4 = linalg.generic {
+      indexing_maps = [#map, #map, #map],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%3 : tensor<?x?xf32>) {
+      ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+        %4 = addf %arg2, %arg3 : f32
+        linalg.yield %4 : f32
+      } -> tensor<?x?xf32>
+  return %4 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @do_not_fold2
+//       CHECK: %[[VAL:.+]] = linalg.tensor_collapse_shape
+//       CHECK: linalg.generic
+//  CHECK-SAME:   ins(%[[VAL]], %{{.+}} : tensor<?x?xf32>, tensor<?x?xf32>)


        


More information about the Mlir-commits mailing list