[Mlir-commits] [mlir] e2b3659 - [mlir][Linalg] Break unnecessary dependency through unused `outs` tensor.

llvmlistbot at llvm.org
Tue May 18 22:31:58 PDT 2021


Author: MaheshRavishankar
Date: 2021-05-18T22:31:42-07:00
New Revision: e2b365948b363636624d5c8cf631f075b19351aa

URL: https://github.com/llvm/llvm-project/commit/e2b365948b363636624d5c8cf631f075b19351aa
DIFF: https://github.com/llvm/llvm-project/commit/e2b365948b363636624d5c8cf631f075b19351aa.diff

LOG: [mlir][Linalg] Break unnecessary dependency through unused `outs` tensor.

LinalgOps whose iterators are all parallel do not use the value of the
`outs` tensor; the semantics are that the `outs` tensor is fully
overwritten. Using anything other than `init_tensor` for `outs` can add
false dependencies between operations when the use is only for the
shape of the tensor. Adding a canonicalization that always uses
`init_tensor` in such cases breaks this dependence.
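
For illustration, a minimal before/after sketch of what the new
`RemoveOutsDependency` pattern does; it reuses the `#trait` and shapes
of the `@break_outs_dependency` test added below, and names such as
`%init`, `%d0`, and `%c0` are illustrative only.

Before the rewrite, the op's `outs` is tied to the result of another
linalg op even though only its shape is needed:

  #map = affine_map<(d0, d1) -> (d0, d1)>
  #trait = {
    indexing_maps = [#map, #map],
    iterator_types = ["parallel", "parallel"]
  }
  // %0 : tensor<?x?xf32> is the result of an earlier linalg.generic.
  %1 = linalg.generic #trait ins(%0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
       ^bb0(%a : f32, %b : f32):
         %r = mulf %a, %a : f32
         linalg.yield %r : f32
       } -> tensor<?x?xf32>

After the rewrite, the `outs` operand is a fresh `linalg.init_tensor`
built only from the shape of %0, so nothing reads the value of %0
through `outs` anymore:

  %c0 = constant 0 : index
  %c1 = constant 1 : index
  // Only the shape of %0 is queried; its value is no longer read here.
  %d0 = memref.dim %0, %c0 : tensor<?x?xf32>
  %d1 = memref.dim %0, %c1 : tensor<?x?xf32>
  %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
  %1 = linalg.generic #trait ins(%0 : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
       ^bb0(%a : f32, %b : f32):
         %r = mulf %a, %a : f32
         linalg.yield %r : f32
       } -> tensor<?x?xf32>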

Differential Revision: https://reviews.llvm.org/D102561

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/test/Dialect/Linalg/fusion-tensor.mlir
    mlir/test/Dialect/Linalg/reshape_fusion.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 4002ef07e900..6ee4d765d5f8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -1310,6 +1310,52 @@ struct FoldReshapeOpsByLinearizationPass
   }
 };
 
+/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
+/// the value of the `outs` operand is not used within the op. This is only
+/// implemented for `linalg.generic` operations for now, but should hold for all
+/// linalg structured ops.
+struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.startRootUpdate(op);
+    bool modifiedOutput = false;
+    Location loc = op.getLoc();
+    for (OpOperand &opOperand : op.getOutputOpOperands()) {
+      if (!op.payloadUsesValueFromOpOperand(&opOperand)) {
+        Value operandVal = opOperand.get();
+        auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
+        if (!operandType)
+          continue;
+
+        // If outs is already an `init_tensor` operation, nothing to do.
+        auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
+        if (definingOp)
+          continue;
+        modifiedOutput = true;
+        SmallVector<Value> dynamicDims;
+        for (auto dim : llvm::enumerate(operandType.getShape())) {
+          if (dim.value() != ShapedType::kDynamicSize)
+            continue;
+          dynamicDims.push_back(rewriter.createOrFold<memref::DimOp>(
+              loc, operandVal, dim.index()));
+        }
+        Value initTensor = rewriter.create<InitTensorOp>(
+            loc, dynamicDims, operandType.getShape(),
+            operandType.getElementType());
+        op->setOperand(opOperand.getOperandNumber(), initTensor);
+      }
+    }
+    if (!modifiedOutput) {
+      rewriter.cancelRootUpdate(op);
+      return failure();
+    }
+    rewriter.finalizeRootUpdate(op);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
@@ -1339,6 +1385,7 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
   auto *context = patterns.getContext();
   patterns.add<FuseElementwiseOps, FoldSplatConstants>(
       context, options.controlElementwiseOpsFusionFn);
+  patterns.add<RemoveOutsDependency>(context);
   populateFoldReshapeOpsByExpansionPatterns(patterns,
                                             options.controlFoldingReshapesFn);
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);

diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 36a1e45839ec..3146b4194f82 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -662,3 +662,39 @@ func @no_fuse_constant_with_reduction() -> tensor<3xf32>
   } -> tensor<3xf32>
   return %result : tensor<3xf32>
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#trait = {
+  indexing_maps = [#map, #map],
+  iterator_types = ["parallel", "parallel"]
+}
+func @break_outs_dependency(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+{
+  %0 = linalg.generic #trait ins(%arg0 : tensor<?x?xf32>) outs(%arg0 : tensor<?x?xf32>) {
+       ^bb0(%arg1 : f32, %arg2 : f32) :
+         %1 = addf %arg1, %arg1 : f32
+         linalg.yield %1 : f32
+       } -> tensor<?x?xf32>
+  %2 = linalg.generic #trait ins(%0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
+       ^bb0(%arg1 : f32, %arg2 : f32) :
+         %3 = mulf %arg1, %arg1 : f32
+         linalg.yield %3 : f32
+       } -> tensor<?x?xf32>
+  return %2 : tensor<?x?xf32>
+}
+//      CHECK: func @break_outs_dependency(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xf32>)
+//  CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//  CHECK-DAG:   %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[D1:.+]] = memref.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
+//      CHECK:   %[[GENERIC1:.+]] = linalg.generic
+// CHECK-SAME:     outs(%[[INIT]] : tensor<?x?xf32>)
+//  CHECK-DAG:   %[[D0:.+]] = memref.dim %[[GENERIC1]], %[[C0]]
+//  CHECK-DAG:   %[[D1:.+]] = memref.dim %[[GENERIC1]], %[[C1]]
+//  CHECK-DAG:   %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
+//      CHECK:   %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME:     outs(%[[INIT]] : tensor<?x?xf32>)

diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 9ff534c8b654..576c83a32dab 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -1,6 +1,5 @@
-// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=false" -split-input-file -verify-each=0 | FileCheck %s
-// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=true" -split-input-file -verify-each=0 | FileCheck %s --check-prefix=FOLDUNITDIM
-
+// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=false" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=true" -split-input-file | FileCheck %s --check-prefix=FOLDUNITDIM
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
 func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
@@ -30,13 +29,11 @@ func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
 // CHECK-SAME:     [0], [1, 2], [3]
 //      CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
 // CHECK-SAME:     [0], [1], [2, 3]
-//      CHECK:   %[[T2:.+]] = linalg.tensor_reshape %[[T0]]
-// CHECK-SAME:     [0], [1], [2, 3]
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP6]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x?x4x?xf32>, tensor<?x?x?x4xf32>)
-// CHECK-SAME:     outs(%[[T2]] : tensor<?x?x?x4xf32>)
+// CHECK-SAME:     outs(%{{.+}} : tensor<?x?x?x4xf32>)
 //      CHECK:   %[[T4:.+]] = linalg.tensor_reshape %[[T3]]
 // CHECK-SAME:     [0], [1], [2, 3]
 // CHECK-SAME:     tensor<?x?x?x4xf32> into tensor<?x?x?xf32>
@@ -73,13 +70,11 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
 //      CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
 // CHECK-SAME:     [0], [1, 2, 3]
 // CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x?x5xf32>
-//      CHECK:   %[[T2:.+]] = linalg.tensor_reshape %[[ARG0]]
-// CHECK-SAME:     [0], [1, 2, 3]
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[T0]], %[[T1]] : tensor<?x4x?x5xf32>, tensor<?x4x?x5xf32>)
-// CHECK-SAME:     outs(%[[T2]] : tensor<?x4x?x5xf32>)
+// CHECK-SAME:     outs(%{{.+}} : tensor<?x4x?x5xf32>)
 //      CHECK:   return %[[T3]] : tensor<?x4x?x5xf32>
 
 
@@ -115,13 +110,11 @@ func @reshape_as_consumer_permutation
 //      CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
 // CHECK-SAME:     [0, 1, 2], [3]
 // CHECK-SAME:     tensor<?x?xf32> into tensor<3x4x?x?xf32>
-//      CHECK:   %[[T2:.+]] = linalg.tensor_reshape %[[ARG0]]
-// CHECK-SAME:     [0, 1], [2], [3, 4, 5]
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[T0]], %[[T1]] : tensor<3x4x?x?x2x?xf32>, tensor<3x4x?x?xf32>)
-// CHECK-SAME:     outs(%[[T2]] : tensor<?x2x?x3x4x?xf32>)
+// CHECK-SAME:     outs(%{{.+}} : tensor<?x2x?x3x4x?xf32>)
 //      CHECK:   return %[[T3]] : tensor<?x2x?x3x4x?xf32>
 
 // -----
@@ -417,13 +410,11 @@ func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
 //      CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
 // CHECK-SAME:     [0, 1, 2], [3]
 // CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x5x?xf32>
-//      CHECK:   %[[T2:.+]] = linalg.tensor_reshape %[[ARG0]]
-// CHECK-SAME:     [0], [1, 2, 3]
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP4]], #[[MAP4]], #[[MAP5]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[T0]], %[[T1]] : tensor<?x4x5x?xf32>, tensor<?x4x5x?xf32>)
-// CHECK-SAME:     outs(%[[T2]] : tensor<?x?x4x5xf32>)
+// CHECK-SAME:     outs(%{{.+}} : tensor<?x?x4x5xf32>)
 //      CHECK:   return %[[T3]] : tensor<?x?x4x5xf32>
 
 // -----
@@ -501,8 +492,7 @@ func @unit_dim_reshape_expansion_full
 //    FOLDUNITDIM-SAME:   %[[ARG0:.+]]: tensor<1x?x1x2x1x4xf32>
 //    FOLDUNITDIM-SAME:   %[[ARG1:.+]]: tensor<?x2x4xf32>
 //     FOLDUNITDIM-DAG:   %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG1]]
-//     FOLDUNITDIM-DAG:   %[[INIT:.+]] = linalg.init_tensor [1, %{{.+}}, 1, 2, 1, 4]
 //         FOLDUNITDIM:   linalg.generic
 //    FOLDUNITDIM-SAME:     ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
-//    FOLDUNITDIM-SAME:     outs(%[[INIT]] : tensor<1x?x1x2x1x4xf32>)
+//    FOLDUNITDIM-SAME:     outs(%{{.+}} : tensor<1x?x1x2x1x4xf32>)
 


        

