[Mlir-commits] [mlir] 32288d3 - [mli][Linalg] NFC: Refactor methods in `ElementwiseOpFusion`.

Mahesh Ravishankar llvmlistbot at llvm.org
Thu Feb 3 10:54:06 PST 2022


Author: Mahesh Ravishankar
Date: 2022-02-03T18:53:13Z
New Revision: 32288d3722b6f06966eb14dcaa0e7a6fd0af077e

URL: https://github.com/llvm/llvm-project/commit/32288d3722b6f06966eb14dcaa0e7a6fd0af077e
DIFF: https://github.com/llvm/llvm-project/commit/32288d3722b6f06966eb14dcaa0e7a6fd0af077e.diff

LOG: [mli][Linalg] NFC: Refactor methods in `ElementwiseOpFusion`.

Reorder the methods and patterns to move related patterns/methods
closer (textually).

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D118870

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 32fd370012c44..a30263990500d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -27,6 +27,10 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
+//===---------------------------------------------------------------------===//
+// Methods and patterns that fuse elementwise `linalg.generic` operations.
+//===---------------------------------------------------------------------===//
+
 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
 /// the `producer` to use in the fused operation given the indexing map of the
 /// result of the producer in the consumer.
@@ -345,6 +349,58 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
   return SmallVector<Value>(fusedOp->getResults());
 }
 
+static Optional<SmallVector<Value>>
+fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
+                   GenericOp producer,
+                   const ControlElementwiseOpsFusionFn &controlFn) {
+  if (producer->getNumResults() != 1)
+    return llvm::None;
+
+  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
+                                rewriter);
+}
+
+namespace {
+/// Pattern to fuse a generic op with the producer of one of its operands.
+class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
+public:
+  FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
+                     PatternBenefit benefit = 1)
+      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    // Find the first operand that is defined by another generic op on tensors.
+    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
+      auto producer =
+          dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
+      if (!producer || !producer.hasTensorSemantics())
+        continue;
+      Optional<SmallVector<Value>> fusedOpResults =
+          fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
+      if (fusedOpResults) {
+        rewriter.replaceOp(genericOp, *fusedOpResults);
+        return success();
+      }
+    }
+    return failure();
+  }
+
+private:
+  ControlElementwiseOpsFusionFn controlFn;
+};
+} // namespace
+
+//===---------------------------------------------------------------------===//
+// Methods and patterns that fuse reshape ops with elementwise operations by
+// linearization of indexing maps.
+//===---------------------------------------------------------------------===//
+
+// TODO(ravishankarm): These patterns need to be deprecated. The indexing maps
+// these produce in the general case are detrimental to transformations.
+// They are useful now only in the limited case of unit-dimension folding.
+// Remove these in favor of more general folding by dimension contraction.
+
 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
 /// provided, given the shape of the source tensor that corresponds to the
 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
@@ -445,6 +501,157 @@ static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
   return true;
 }
 
+namespace {
+/// Pattern to fold tensor_expand_shape op with its consumer by using the source
+/// of the reshape op as the operand in the consumer (instead of the result of
+/// the reshape op). The corresponding index map in the consumer
+/// needs to be modified to linearize the folded dimension.
+///
+/// For example,
+///
+/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+/// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
+///      tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
+/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
+///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
+///        -> tensor<?x?x4x?xf32>
+///
+/// can be folded into
+///
+/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
+///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
+///        -> tensor<?x?x4x?xf32>
+template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
+struct FoldProducerReshapeOpByLinearization
+    : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
+    for (const auto &en : llvm::enumerate(inputOperands)) {
+      auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
+      if (!reshapeOp)
+        continue;
+
+      if (!isTensorReshapeOpFoldableByLinearization(
+              reshapeOp, genericOp.getTiedIndexingMap(en.value()),
+              /*asProducer =*/true) ||
+          (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
+        continue;
+
+      // Compute the fused operands list,
+      SmallVector<Value> fusedOperands = genericOp.getInputOperands();
+      fusedOperands[en.index()] = reshapeOp.src();
+      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+      llvm::append_range(fusedOperands, outputOperands);
+
+      // Compute indexing_maps for the fused operation. The indexing_maps for
+      // the operands of the consumers that arent fused are the same.
+      SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
+
+      // Compute the indexing map to use for the result of the producer.
+      AffineMap modifiedMap =
+          linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
+      // The modified map cannot have symbols.
+      if (modifiedMap.getNumSymbols())
+        return failure();
+      for (AffineExpr expr : modifiedMap.getResults()) {
+        if (!expr.isPureAffine())
+          return failure();
+      }
+      fusedIndexMaps[en.index()] = modifiedMap;
+
+      // Further check that the resulting index maps can be fused and
+      // inverted. Without this the resultant op is not legal.
+      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
+        return rewriter.notifyMatchFailure(
+            genericOp, "fused op loop bound computation failed");
+      }
+
+      rewriter.startRootUpdate(genericOp);
+      genericOp->setOperands(fusedOperands);
+      genericOp.indexing_mapsAttr(
+          rewriter.getAffineMapArrayAttr(fusedIndexMaps));
+      rewriter.finalizeRootUpdate(genericOp);
+      return success();
+    }
+    return failure();
+  }
+};
+
+/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
+/// producer. The indexing map of the producer's output needs to be modified
+/// to linearize the folded dimension.
+template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
+struct FoldConsumerReshapeOpByLinearization
+    : public OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
+    if (!producer || !producer.hasTensorSemantics() ||
+        producer.getNumOutputs() != 1 ||
+        !isTensorReshapeOpFoldableByLinearization(
+            reshapeOp,
+            producer.getTiedIndexingMap(producer.getOutputOperand(0)),
+            /*asProducer =*/false) ||
+        (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
+      return failure();
+    // The indexing_maps for the operands of the fused operation are same as
+    // those for the operands of the producer.
+    SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
+
+    // Compute the indexing map to use for the operand of the producer.
+    AffineMap modifiedMap = linearizeCollapsedDims(
+        producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
+    for (AffineExpr expr : modifiedMap.getResults()) {
+      if (!expr.isPureAffine()) {
+        return rewriter.notifyMatchFailure(
+            producer, "fused op indexing map is not affine");
+      }
+    }
+    fusedIndexMaps.back() = modifiedMap;
+
+    // Further check that the resulting index maps can be fused and
+    // inverted. Without this the resultant op is not legal.
+    if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
+      return rewriter.notifyMatchFailure(
+          producer, "fused op loop bound computation failed");
+    }
+
+    Location loc = producer.getLoc();
+    SmallVector<Value> inputOperands = producer.getInputOperands();
+    Value output = rewriter.create<TensorReshapeOp>(
+        loc, producer.getOutputOperand(0)->get(),
+        reshapeOp.getReassociationExprs());
+    auto fusedOp = rewriter.create<GenericOp>(
+        loc, reshapeOp.getResultType(),
+        /*inputs=*/inputOperands,
+        // TODO: handle outputs.
+        /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+        producer.iterator_types(),
+        /*doc=*/nullptr,
+        /*library_call=*/nullptr);
+    auto &fusedRegion = fusedOp->getRegion(0);
+    rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
+                               fusedRegion.begin());
+    rewriter.replaceOp(reshapeOp, fusedOp->getResults());
+    return success();
+  }
+};
+} // namespace
+
+//===---------------------------------------------------------------------===//
+// Methods and patterns that fuse reshape ops with elementwise operations by
+// expanding the dimensionality of the elementwise operations.
+//===---------------------------------------------------------------------===//
+
 /// Conditions for folding a generic operation with a reshape op by expanding
 /// the iteration space dimensionality for tensor operations. These are
 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the
@@ -612,9 +819,9 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
 /// Note that this could be extended to handle dynamic case, but the
 /// implementation below uses `affine.apply` which seems to have issues when the
 /// shapes are not static.
-LogicalResult isGenericOpExpandable(GenericOp genericOp,
-                                    const ExpansionInfo &expansionInfo,
-                                    PatternRewriter &rewriter) {
+static LogicalResult isGenericOpExpandable(GenericOp genericOp,
+                                           const ExpansionInfo &expansionInfo,
+                                           PatternRewriter &rewriter) {
   if (!genericOp.hasIndexSemantics())
     return success();
   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
@@ -863,88 +1070,85 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
 
 namespace {
 
-/// Pattern to fold tensor_expand_shape op with its consumer by using the source
-/// of the reshape op as the operand in the consumer (instead of the result of
-/// the tensor_collapse_shape). The corresponding index map in the consumer
-/// needs to be modified to linearize the folded dimension.
-///
-/// For example,
-///
-/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-/// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
-///      tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
-/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
-///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
-///        -> tensor<?x?x4x?xf32>
-///
-/// can be folded into
-///
-/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
-///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
-///        -> tensor<?x?x4x?xf32>
-template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
-struct FoldProducerReshapeOpByLinearization
+/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
+/// when the reshape op is collapsing dimensions. The dimensionality of the loop
+/// in the consumer is expanded.
+class FoldWithProducerReshapeOpByExpansion
     : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
+public:
+  FoldWithProducerReshapeOpByExpansion(
+      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<GenericOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
-      return failure();
-    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
-    for (const auto &en : llvm::enumerate(inputOperands)) {
-      auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
+    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
+      tensor::CollapseShapeOp reshapeOp =
+          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
       if (!reshapeOp)
         continue;
-
-      if (!isTensorReshapeOpFoldableByLinearization(
-              reshapeOp, genericOp.getTiedIndexingMap(en.value()),
-              /*asProducer =*/true) ||
-          (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
+      // Fold only if
+      // - The tensor reshape op is folding.
+      // - All constraints of fusing with reshape by expansion are met.
+      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
+          (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
         continue;
 
-      // Compute the fused operands list,
-      SmallVector<Value> fusedOperands = genericOp.getInputOperands();
-      fusedOperands[en.index()] = reshapeOp.src();
-      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
-      llvm::append_range(fusedOperands, outputOperands);
-
-      // Compute indexing_maps for the fused operation. The indexing_maps for
-      // the operands of the consumers that arent fused are the same.
-      SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
-
-      // Compute the indexing map to use for the result of the producer.
-      AffineMap modifiedMap =
-          linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
-      // The modified map cannot have symbols.
-      if (modifiedMap.getNumSymbols())
+      Optional<SmallVector<Value>> replacementValues =
+          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
+      if (!replacementValues)
         return failure();
-      for (AffineExpr expr : modifiedMap.getResults()) {
-        if (!expr.isPureAffine())
-          return failure();
-      }
-      fusedIndexMaps[en.index()] = modifiedMap;
-
-      // Further check that the resulting index maps can be fused and
-      // inverted. Without this the resultant op is not legal.
-      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
-        return rewriter.notifyMatchFailure(
-            genericOp, "fused op loop bound computation failed");
-      }
-
-      rewriter.startRootUpdate(genericOp);
-      genericOp->setOperands(fusedOperands);
-      genericOp.indexing_mapsAttr(
-          rewriter.getAffineMapArrayAttr(fusedIndexMaps));
-      rewriter.finalizeRootUpdate(genericOp);
+      rewriter.replaceOp(genericOp, replacementValues.getValue());
       return success();
     }
     return failure();
   }
+
+private:
+  ControlElementwiseOpsFusionFn controlFoldingReshapes;
 };
 
+/// Pattern to fold a tensor_expand_shape op with its producer generic op
+/// by expanding the dimensionality of the loop in the producer op.
+struct FoldReshapeWithGenericOpByExpansion
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+  FoldReshapeWithGenericOpByExpansion(
+      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    // Fold only if all constraints of fusing with reshape by expansion are met.
+    GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
+    if (!producer || producer.getNumOutputs() != 1 ||
+        !isFusableWithReshapeByDimExpansion(producer,
+                                            producer.getOutputOperand(0)) ||
+        !controlFoldingReshapes(producer->getResult(0),
+                                reshapeOp->getOpOperand(0)))
+      return failure();
+    Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
+        producer, reshapeOp, producer.getOutputOperand(0), rewriter);
+    if (!replacementValues)
+      return failure();
+    rewriter.replaceOp(reshapeOp, replacementValues.getValue());
+    return success();
+  }
+
+private:
+  ControlElementwiseOpsFusionFn controlFoldingReshapes;
+};
+} // namespace
+
+//===---------------------------------------------------------------------===//
+// Methods and patterns to convert tensor.expand_shape -> linalg.generic
+// into linalg.generic -> tensor.expand_shape, i.e. push the reshape down.
+//===---------------------------------------------------------------------===//
+
 static SmallVector<ReassociationIndices>
 getReassociationIndices(ArrayRef<AffineMap> maps) {
   SmallVector<ReassociationIndices> reassociation;
@@ -959,6 +1163,7 @@ getReassociationIndices(ArrayRef<AffineMap> maps) {
   return reassociation;
 }
 
+namespace {
 /// Pattern to move rank reducing reshape after an elementwise linalg generic
 /// op. This is useful to expose more fusion opportunities between named ops and
 /// generic ops. This can only be done if there is no broadcast or permuation
@@ -1100,142 +1305,13 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
     return success();
   }
 };
+} // namespace
 
-/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
-/// when the reshape op is collapsing dimensions. The dimensionality of the loop
-/// in the consumer is expanded.
-class FoldWithProducerReshapeOpByExpansion
-    : public OpRewritePattern<GenericOp> {
-public:
-  FoldWithProducerReshapeOpByExpansion(
-      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
-      PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit),
-        controlFoldingReshapes(std::move(foldReshapes)) {}
-
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
-      tensor::CollapseShapeOp reshapeOp =
-          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
-      if (!reshapeOp)
-        continue;
-      // Fold only if
-      // - The tensor reshape op is folding.
-      // - All constraints of fusing with reshape by expansion are met.
-      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
-          (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
-        continue;
-
-      Optional<SmallVector<Value>> replacementValues =
-          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
-      if (!replacementValues)
-        return failure();
-      rewriter.replaceOp(genericOp, replacementValues.getValue());
-      return success();
-    }
-    return failure();
-  }
-
-private:
-  ControlElementwiseOpsFusionFn controlFoldingReshapes;
-};
-
-/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
-/// producer. The corresponding index map in the consumer needs to be modified
-/// to linearize the folded dimension.
-template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
-struct FoldConsumerReshapeOpByLinearization
-    : public OpRewritePattern<TensorReshapeOp> {
-  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
-                                PatternRewriter &rewriter) const override {
-    GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
-    if (!producer || !producer.hasTensorSemantics() ||
-        producer.getNumOutputs() != 1 ||
-        !isTensorReshapeOpFoldableByLinearization(
-            reshapeOp,
-            producer.getTiedIndexingMap(producer.getOutputOperand(0)),
-            /*asProducer =*/false) ||
-        (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
-      return failure();
-    // The indexing_maps for the operands of the fused operation are same as
-    // those for the operands of the producer.
-    SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
-
-    // Compute the indexing map to use for the operand of the producer.
-    AffineMap modifiedMap = linearizeCollapsedDims(
-        producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
-    for (AffineExpr expr : modifiedMap.getResults()) {
-      if (!expr.isPureAffine()) {
-        return rewriter.notifyMatchFailure(
-            producer, "fused op indexing map is not affine");
-      }
-    }
-    fusedIndexMaps.back() = modifiedMap;
-
-    // Further check that the resulting index maps can be fused and
-    // inverted. Without this the resultant op is not legal.
-    if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
-      return rewriter.notifyMatchFailure(
-          producer, "fused op loop bound computation failed");
-    }
-
-    Location loc = producer.getLoc();
-    SmallVector<Value> inputOperands = producer.getInputOperands();
-    Value output = rewriter.create<TensorReshapeOp>(
-        loc, producer.getOutputOperand(0)->get(),
-        reshapeOp.getReassociationExprs());
-    auto fusedOp = rewriter.create<GenericOp>(
-        loc, reshapeOp.getResultType(),
-        /*inputs=*/inputOperands,
-        // TODO: handle outputs.
-        /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
-        producer.iterator_types(),
-        /*doc=*/nullptr,
-        /*library_call=*/nullptr);
-    auto &fusedRegion = fusedOp->getRegion(0);
-    rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
-                               fusedRegion.begin());
-    rewriter.replaceOp(reshapeOp, fusedOp->getResults());
-    return success();
-  }
-};
-
-/// Pattern to fold a tensor_expand_shape op with its producer generic op
-/// by expanding the dimensionality of the loop in the producer op.
-struct FoldReshapeWithGenericOpByExpansion
-    : public OpRewritePattern<tensor::ExpandShapeOp> {
-
-  FoldReshapeWithGenericOpByExpansion(
-      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
-      PatternBenefit benefit = 1)
-      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
-        controlFoldingReshapes(std::move(foldReshapes)) {}
-
-  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
-                                PatternRewriter &rewriter) const override {
-    // Fold only if all constraints of fusing with reshape by expansion are met.
-    GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
-    if (!producer || producer.getNumOutputs() != 1 ||
-        !isFusableWithReshapeByDimExpansion(producer,
-                                            producer.getOutputOperand(0)) ||
-        !controlFoldingReshapes(producer->getResult(0),
-                                reshapeOp->getOpOperand(0)))
-      return failure();
-    Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
-        producer, reshapeOp, producer.getOutputOperand(0), rewriter);
-    if (!replacementValues)
-      return failure();
-    rewriter.replaceOp(reshapeOp, replacementValues.getValue());
-    return success();
-  }
-
-private:
-  ControlElementwiseOpsFusionFn controlFoldingReshapes;
-};
+//===---------------------------------------------------------------------===//
+// Methods and patterns that fuse constants with linalg.generic operations.
+//===---------------------------------------------------------------------===//
 
+namespace {
 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
 /// handle cases where the constant is not single-valued.
 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
@@ -1624,98 +1700,11 @@ struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
 
 } // namespace
 
-static Optional<SmallVector<Value>>
-fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
-                   GenericOp producer,
-                   const ControlElementwiseOpsFusionFn &controlFn) {
-  if (producer->getNumResults() != 1)
-    return llvm::None;
-
-  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
-                                rewriter);
-}
-
-bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
-                                      OpOperand &consumer) {
-  if (auto producerCollapseOp =
-          dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
-    return !isUnitDimExpansionOnly(producerCollapseOp);
-  }
-  if (auto consumerExpandOp =
-          dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
-    return !isUnitDimExpansionOnly(consumerExpandOp);
-  }
-  return true;
-}
+//===---------------------------------------------------------------------===//
+// Miscellaneous patterns that help fusion.
+//===---------------------------------------------------------------------===//
 
 namespace {
-/// Patterns to fuse a generic op, with the producer of its operands.
-class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
-public:
-  FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
-                     PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
-
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    // Find the first operand that is defined by another generic op on tensors.
-    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
-      auto producer =
-          dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
-      if (!producer || !producer.hasTensorSemantics())
-        continue;
-      Optional<SmallVector<Value>> fusedOpResults =
-          fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
-      if (fusedOpResults) {
-        rewriter.replaceOp(genericOp, *fusedOpResults);
-        return success();
-      }
-    }
-    return failure();
-  }
-
-private:
-  ControlElementwiseOpsFusionFn controlFn;
-};
-
-/// Pass that fuses generic ops on tensors. Used only for testing.
-struct LinalgElementwiseOpFusionPass
-    : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
-  void runOnOperation() override {
-    Operation *op = getOperation();
-    RewritePatternSet patterns(op->getContext());
-    ControlElementwiseOpsFusionFn allowFoldingFn =
-        [](const OpResult &producer, const OpOperand &consumer) {
-          return true;
-        };
-    populateElementwiseOpsFusionPatterns(
-        patterns,
-        LinalgElementwiseFusionOptions().setControlFoldingReshapes(
-            allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
-
-    // Use TopDownTraversal for compile time reasons
-    GreedyRewriteConfig grc;
-    grc.useTopDownTraversal = true;
-    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
-                                       grc);
-  }
-};
-
-/// Pass to test folding of reshape ops with generic ops by linearization.
-struct FoldReshapeOpsByLinearizationPass
-    : public LinalgFoldReshapeOpsByLinearizationBase<
-          FoldReshapeOpsByLinearizationPass> {
-  void runOnOperation() override {
-    Operation *op = getOperation();
-    RewritePatternSet patterns(op->getContext());
-    populateFoldReshapeOpsByLinearizationPatterns(patterns);
-    if (allowFoldingUnitDimReshapes) {
-      populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
-    }
-    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
-  }
-};
-
 /// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
 /// the value of the `outs` operand is not used within the op.  This is only
 /// implemented for `linalg.generic` operations for now, but should hold for all
@@ -1761,9 +1750,12 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
     return success();
   }
 };
-
 } // namespace
 
+//===---------------------------------------------------------------------===//
+// Methods that add patterns described in this file to a pattern list.
+//===---------------------------------------------------------------------===//
+
 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns) {
   patterns
@@ -1815,6 +1807,65 @@ void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
   patterns.add<PushExpandingReshape>(context);
 }
 
+//===---------------------------------------------------------------------===//
+// Passes
+//===---------------------------------------------------------------------===//
+
+bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
+                                      OpOperand &consumer) {
+  if (auto producerCollapseOp =
+          dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
+    return !isUnitDimExpansionOnly(producerCollapseOp);
+  }
+  if (auto consumerExpandOp =
+          dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
+    return !isUnitDimExpansionOnly(consumerExpandOp);
+  }
+  return true;
+}
+
+namespace {
+
+/// Pass that fuses generic ops on tensors. Used only for testing.
+struct LinalgElementwiseOpFusionPass
+    : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    ControlElementwiseOpsFusionFn allowFoldingFn =
+        [](const OpResult &producer, const OpOperand &consumer) {
+          return true;
+        };
+    populateElementwiseOpsFusionPatterns(
+        patterns,
+        LinalgElementwiseFusionOptions().setControlFoldingReshapes(
+            allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
+
+    // Use TopDownTraversal for compile time reasons
+    GreedyRewriteConfig grc;
+    grc.useTopDownTraversal = true;
+    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
+                                       grc);
+  }
+};
+
+/// Pass to test folding of reshape ops with generic ops by linearization.
+struct FoldReshapeOpsByLinearizationPass
+    : public LinalgFoldReshapeOpsByLinearizationBase<
+          FoldReshapeOpsByLinearizationPass> {
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    populateFoldReshapeOpsByLinearizationPatterns(patterns);
+    if (allowFoldingUnitDimReshapes) {
+      populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
+    }
+    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
+  }
+};
+
+} // namespace
+
 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
   return std::make_unique<LinalgElementwiseOpFusionPass>();
 }


        


More information about the Mlir-commits mailing list