[Mlir-commits] [mlir] eb4efa8 - [mlir][Linalg] Enhance Linalg fusion on generic op and tensor_reshape op.

Hanhan Wang llvmlistbot at llvm.org
Fri Aug 28 01:56:06 PDT 2020


Author: Hanhan Wang
Date: 2020-08-28T01:55:49-07:00
New Revision: eb4efa883212352b2b32ba8aca8525ad17898ed4

URL: https://github.com/llvm/llvm-project/commit/eb4efa883212352b2b32ba8aca8525ad17898ed4
DIFF: https://github.com/llvm/llvm-project/commit/eb4efa883212352b2b32ba8aca8525ad17898ed4.diff

LOG: [mlir][Linalg] Enhance Linalg fusion on generic op and tensor_reshape op.

The tensor_reshape op was fusible only if it was a collapsing case. Now we
propagate the op to all the operands, so there is a further chance to fuse it
with the generic op. The pre-conditions are:

1) The producer is not an indexed_generic op.
2) All the shapes of the operands are the same.
3) All the indexing maps are identity.
4) All the loops are parallel loops.
5) The producer has a single user.

It is also possible to fuse the ops when the producer is an indexed_generic op,
because the original indices can still be computed. E.g., if the reshape op
collapses d0 and d1, we can use DimOp to get the width of d1, calculate the
index `d0 * width + d1`, and replace all the uses with it. However, this pattern
is not implemented in this patch.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D86314

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/test/Dialect/Linalg/fusion-tensor.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1b8a22eecc9e..fa45997ae801 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -396,7 +396,8 @@ struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
 } // namespace
 
 template <typename ReshapeOpTy>
-static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp) {
+static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
+                                  ArrayRef<Attribute> operands) {
   // Fold producer-consumer reshape ops that where the operand type of the
   // producer is same as the return type of the consumer. This can only be
   // verified if the shapes in question are static.
@@ -406,6 +407,10 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp) {
       reshapeOp.getResultType().hasStaticShape() &&
       reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
     return reshapeSrcOp.src();
+  if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
+    return elements.reshape(
+        reshapeOp.getResult().getType().template cast<ShapedType>());
+  }
   return nullptr;
 }
 
@@ -1175,18 +1180,18 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) {
 // TODO: Consider making all this boilerplate easy to autogenerate
 // with Tablegen. This seems a desirable property in the context of OpInterfaces
 // where a Linalg "named" op **isa** LinalgOp.
-OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
+OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
-  return foldReshapeOp(*this);
+  return foldReshapeOp(*this, operands);
 }
 OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
   return {};
 }
-OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute>) {
-  return foldReshapeOp(*this);
+OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
+  return foldReshapeOp(*this, operands);
 }
 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldMemRefCast(*this)))

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 9080a202a824..126228bce25a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -773,6 +773,9 @@ struct FuseTensorReshapeOpAsProducer {
   static LinalgOp fuse(TensorReshapeOp producer, LinalgOp consumer,
                        unsigned consumerIdx, PatternRewriter &rewriter,
                        OperationFolder *folder = nullptr) {
+    if (producer.src().getDefiningOp<ConstantOp>())
+      return nullptr;
+
     if (!isFusible(producer, consumer, consumerIdx))
       return nullptr;
 
@@ -826,20 +829,19 @@ struct FuseTensorReshapeOpAsProducer {
 
 /// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp.
 struct FuseTensorReshapeOpAsConsumer {
-  static bool isFusible(LinalgOp producer, TensorReshapeOp consumer,
-                        unsigned consumerIdx) {
+  static bool isCollapsingAndFusible(LinalgOp producer,
+                                     TensorReshapeOp consumer,
+                                     unsigned consumerIdx) {
     return isa<GenericOp, IndexedGenericOp>(producer.getOperation()) &&
            producer.hasTensorSemantics() &&
            isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
                                     /*asProducer=*/false);
   }
 
-  static LinalgOp fuse(LinalgOp producer, TensorReshapeOp consumer,
-                       unsigned consumerIdx, PatternRewriter &rewriter,
-                       OperationFolder *folder = nullptr) {
-    if (!isFusible(producer, consumer, consumerIdx))
-      return nullptr;
-
+  static LinalgOp fuseCollapsingCase(LinalgOp producer,
+                                     TensorReshapeOp consumer,
+                                     unsigned consumerIdx,
+                                     PatternRewriter &rewriter) {
     // The indexing_maps for the operands of the fused operation are same as
     // those for the operands of the producer.
     SmallVector<AffineMap, 4> fusedIndexMaps =
@@ -882,6 +884,94 @@ struct FuseTensorReshapeOpAsConsumer {
                                fusedRegion.begin());
     return fusedOp;
   }
+
+  /// Returns true if `consumer` expands the result of the generic op
+  /// `producer` and the reshape can be propagated to all operands of
+  /// `producer`. This is the case only if:
+  ///   1) The producer is a generic op.
+  ///   2) The producer has tensor semantics.
+  ///   3) The tensor reshape op is an expanding case.
+  ///   4) All the operand and result shapes of the producer are the same.
+  ///   5) All the indexing maps in the producer are identity.
+  ///   6) All the loops in the producer are parallel loops.
+  ///   7) The producer has a single output whose only user is the consumer.
+  ///      fuseExpandingCase creates one result per producer output, while
+  ///      the reshape being replaced has only one result.
+  static bool isExpandingAndFusible(LinalgOp producer, TensorReshapeOp consumer,
+                                    unsigned consumerIdx) {
+    auto types = producer.getInputOutputShapedTypes();
+    assert(!types.empty());
+    ArrayRef<int64_t> shape = types.front().getShape();
+    return isa<GenericOp>(producer.getOperation()) &&
+           producer.hasTensorSemantics() &&
+           consumer.getSrcType().getRank() <
+               consumer.getResultType().getRank() &&
+           llvm::all_of(types,
+                        [&](ShapedType type) {
+                          return type.getShape() == shape;
+                        }) &&
+           llvm::all_of(producer.getIndexingMaps(),
+                        [](AffineMap map) { return map.isIdentity(); }) &&
+           llvm::all_of(producer.iterator_types(),
+                        [](Attribute attr) {
+                          return attr.cast<StringAttr>().getValue() ==
+                                 getParallelIteratorTypeName();
+                        }) &&
+           producer.getNumOutputs() == 1 &&
+           producer.getOperation()->hasOneUse();
+  }
+
+  /// Creates a generic op that applies the region of `producer` on the
+  /// expanded shape of `consumer`: every operand of `producer` is reshaped
+  /// with the reassociation of `consumer` (folded when possible, e.g. for
+  /// constants), all indexing maps are identity maps of the expanded rank and
+  /// all loops are parallel. Callers must check `isExpandingAndFusible` first.
+  static LinalgOp fuseExpandingCase(LinalgOp producer, TensorReshapeOp consumer,
+                                    unsigned consumerIdx,
+                                    PatternRewriter &rewriter) {
+    Location loc = producer.getLoc();
+    auto dstShape = consumer.getResultType().cast<ShapedType>().getShape();
+    SmallVector<Value, 4> args;
+    for (auto arg : producer.getOperation()->getOperands()) {
+      auto type = RankedTensorType::get(
+          dstShape, arg.getType().cast<ShapedType>().getElementType());
+      args.push_back(rewriter.createOrFold<linalg::TensorReshapeOp>(
+          loc, type, arg, consumer.reassociation()));
+    }
+
+    SmallVector<Type, 4> resultTypes;
+    for (auto t : producer.getOutputTensorTypes()) {
+      Type type = RankedTensorType::get(dstShape,
+                                        t.cast<ShapedType>().getElementType());
+      resultTypes.push_back(type);
+    }
+
+    int rank = dstShape.size();
+    int numArgsIn = producer.getNumInputs();
+    int numArgsOut = producer.getNumOutputs();
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, resultTypes, args, numArgsIn, numArgsOut,
+        SmallVector<AffineMap, 3>(args.size() + resultTypes.size(),
+                                  rewriter.getMultiDimIdentityMap(rank)),
+        SmallVector<StringRef, 3>(rank, getParallelIteratorTypeName()));
+    Region &region = genericOp.getRegion();
+    rewriter.cloneRegionBefore(producer.getOperation()->getRegion(0), region,
+                               region.begin());
+    return cast<LinalgOp>(genericOp.getOperation());
+  }
+
+  /// Fuses `producer` into the tensor_reshape `consumer`, trying the
+  /// collapsing case first and then the expanding case. Returns nullptr if
+  /// neither applies. `folder` is currently unused.
+  static LinalgOp fuse(LinalgOp producer, TensorReshapeOp consumer,
+                       unsigned consumerIdx, PatternRewriter &rewriter,
+                       OperationFolder *folder = nullptr) {
+    if (isCollapsingAndFusible(producer, consumer, consumerIdx))
+      return fuseCollapsingCase(producer, consumer, consumerIdx, rewriter);
+    if (isExpandingAndFusible(producer, consumer, consumerIdx))
+      return fuseExpandingCase(producer, consumer, consumerIdx, rewriter);
+    return nullptr;
+  }
 };
 
 /// Implementation of fusion on tensor ops when producer is a splat constant.

diff  --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 4e7f1f6152b1..ac2d3e260a46 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -222,6 +222,40 @@ func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d2)>
+
+func @generic_op_reshape_consumer_expanding(%arg0: tensor<264x4xf32>)
+                                            -> tensor<8x33x4xf32> {
+  %cst = constant dense<2.000000e+00> : tensor<264x4xf32>
+  %0 = linalg.generic
+    {args_in = 2 : i64, args_out = 1 : i64,
+     indexing_maps = [#map0, #map0, #map0],
+     iterator_types = ["parallel", "parallel"]}
+    %arg0, %cst {
+    ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
+      %2 = mulf %arg1, %arg2 : f32
+      linalg.yield %2 : f32
+    }: tensor<264x4xf32>, tensor<264x4xf32> -> tensor<264x4xf32>
+  %1 = linalg.tensor_reshape %0 [#map1, #map2] :
+    tensor<264x4xf32> into tensor<8x33x4xf32>
+  return %1 : tensor<8x33x4xf32>
+}
+
+// The reshape op in `%arg0` is folded into the indexing map of generic op.
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0 * 33 + d1, d2)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//       CHECK: func @generic_op_reshape_consumer_expanding
+//   CHECK-NOT:   linalg.tensor_reshape
+//       CHECK:   %[[CST:.*]] = constant {{.*}} : f32
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+//       CHECK:   tensor<264x4xf32> -> tensor<8x33x4xf32>
+//   CHECK-NOT:   linalg.tensor_reshape
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2) -> (d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>


        


More information about the Mlir-commits mailing list