[Mlir-commits] [mlir] 32288d3 - [mlir][Linalg] NFC: Refactor methods in `ElementwiseOpFusion`.
Mahesh Ravishankar
llvmlistbot at llvm.org
Thu Feb 3 10:54:06 PST 2022
Author: Mahesh Ravishankar
Date: 2022-02-03T18:53:13Z
New Revision: 32288d3722b6f06966eb14dcaa0e7a6fd0af077e
URL: https://github.com/llvm/llvm-project/commit/32288d3722b6f06966eb14dcaa0e7a6fd0af077e
DIFF: https://github.com/llvm/llvm-project/commit/32288d3722b6f06966eb14dcaa0e7a6fd0af077e.diff
LOG: [mlir][Linalg] NFC: Refactor methods in `ElementwiseOpFusion`.
Reorder the methods and patterns to move related patterns/methods
closer (textually).
Reviewed By: gysit
Differential Revision: https://reviews.llvm.org/D118870
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 32fd370012c44..a30263990500d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -27,6 +27,10 @@
using namespace mlir;
using namespace mlir::linalg;
+//===---------------------------------------------------------------------===//
+// Methods and patterns that fuse elementwise `linalg.generic` operations.
+//===---------------------------------------------------------------------===//
+
/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
/// the `producer` to use in the fused operation given the indexing map of the
/// result of the producer in the consumer.
@@ -345,6 +349,58 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
return SmallVector<Value>(fusedOp->getResults());
}
+static Optional<SmallVector<Value>>
+fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
+ GenericOp producer,
+ const ControlElementwiseOpsFusionFn &controlFn) {
+ if (producer->getNumResults() != 1)
+ return llvm::None;
+
+ return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
+ rewriter);
+}
+
+namespace {
+/// Patterns to fuse a generic op, with the producer of its operands.
+class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
+public:
+ FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
+
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ // Find the first operand that is defined by another generic op on tensors.
+ for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
+ auto producer =
+ dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
+ if (!producer || !producer.hasTensorSemantics())
+ continue;
+ Optional<SmallVector<Value>> fusedOpResults =
+ fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
+ if (fusedOpResults) {
+ rewriter.replaceOp(genericOp, *fusedOpResults);
+ return success();
+ }
+ }
+ return failure();
+ }
+
+private:
+ ControlElementwiseOpsFusionFn controlFn;
+};
+} // namespace
+
+//===---------------------------------------------------------------------===//
+// Methods and patterns that fuse reshape ops with elementwise operations by
+// linearization of indexing maps.
+//===---------------------------------------------------------------------===//
+
+// TODO(ravishankarm): These patterns need to be deprecated. The indexing maps
+// these produce in the general case are detrimental to transformations.
+// They are useful now only in the limited case of unit-dimension folding.
+// Remove these in favor of more general folding by dimension contraction.
+
/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
/// provided, given the shape of the source tensor that corresponds to the
/// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
@@ -445,6 +501,157 @@ static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
return true;
}
+namespace {
+/// Pattern to fold tensor_expand_shape op with its consumer by using the source
+/// of the reshape op as the operand in the consumer (instead of the result of
+/// the tensor_collapse_shape). The corresponding index map in the consumer
+/// needs to be modified to linearize the folded dimension.
+///
+/// For example,
+///
+/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+/// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
+/// tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
+/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
+/// ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
+/// -> tensor<?x?x4x?xf32>
+///
+/// can be folded into
+///
+/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
+/// ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
+/// -> tensor<?x?x4x?xf32>
+template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
+struct FoldProducerReshapeOpByLinearization
+ : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ if (!genericOp.hasTensorSemantics())
+ return failure();
+ SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
+ for (const auto &en : llvm::enumerate(inputOperands)) {
+ auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
+ if (!reshapeOp)
+ continue;
+
+ if (!isTensorReshapeOpFoldableByLinearization(
+ reshapeOp, genericOp.getTiedIndexingMap(en.value()),
+ /*asProducer =*/true) ||
+ (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
+ continue;
+
+ // Compute the fused operands list,
+ SmallVector<Value> fusedOperands = genericOp.getInputOperands();
+ fusedOperands[en.index()] = reshapeOp.src();
+ SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+ llvm::append_range(fusedOperands, outputOperands);
+
+ // Compute indexing_maps for the fused operation. The indexing_maps for
+ // the operands of the consumers that aren't fused are the same.
+ SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
+
+ // Compute the indexing map to use for the result of the producer.
+ AffineMap modifiedMap =
+ linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
+ // The modified map cannot have symbols.
+ if (modifiedMap.getNumSymbols())
+ return failure();
+ for (AffineExpr expr : modifiedMap.getResults()) {
+ if (!expr.isPureAffine())
+ return failure();
+ }
+ fusedIndexMaps[en.index()] = modifiedMap;
+
+ // Further check that the resulting index maps can be fused and
+ // inverted. Without this the resultant op is not legal.
+ if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "fused op loop bound computation failed");
+ }
+
+ rewriter.startRootUpdate(genericOp);
+ genericOp->setOperands(fusedOperands);
+ genericOp.indexing_mapsAttr(
+ rewriter.getAffineMapArrayAttr(fusedIndexMaps));
+ rewriter.finalizeRootUpdate(genericOp);
+ return success();
+ }
+ return failure();
+ }
+};
+
+/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
+/// producer. The corresponding index map in the consumer needs to be modified
+/// to linearize the folded dimension.
+template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
+struct FoldConsumerReshapeOpByLinearization
+ : public OpRewritePattern<TensorReshapeOp> {
+ using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+ PatternRewriter &rewriter) const override {
+ GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
+ if (!producer || !producer.hasTensorSemantics() ||
+ producer.getNumOutputs() != 1 ||
+ !isTensorReshapeOpFoldableByLinearization(
+ reshapeOp,
+ producer.getTiedIndexingMap(producer.getOutputOperand(0)),
+ /*asProducer =*/false) ||
+ (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
+ return failure();
+ // The indexing_maps for the operands of the fused operation are same as
+ // those for the operands of the producer.
+ SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
+
+ // Compute the indexing map to use for the operand of the producer.
+ AffineMap modifiedMap = linearizeCollapsedDims(
+ producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
+ for (AffineExpr expr : modifiedMap.getResults()) {
+ if (!expr.isPureAffine()) {
+ return rewriter.notifyMatchFailure(
+ producer, "fused op indexing map is not affine");
+ }
+ }
+ fusedIndexMaps.back() = modifiedMap;
+
+ // Further check that the resulting index maps can be fused and
+ // inverted. Without this the resultant op is not legal.
+ if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
+ return rewriter.notifyMatchFailure(
+ producer, "fused op loop bound computation failed");
+ }
+
+ Location loc = producer.getLoc();
+ SmallVector<Value> inputOperands = producer.getInputOperands();
+ Value output = rewriter.create<TensorReshapeOp>(
+ loc, producer.getOutputOperand(0)->get(),
+ reshapeOp.getReassociationExprs());
+ auto fusedOp = rewriter.create<GenericOp>(
+ loc, reshapeOp.getResultType(),
+ /*inputs=*/inputOperands,
+ // TODO: handle outputs.
+ /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+ producer.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr);
+ auto &fusedRegion = fusedOp->getRegion(0);
+ rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
+ fusedRegion.begin());
+ rewriter.replaceOp(reshapeOp, fusedOp->getResults());
+ return success();
+ }
+};
+} // namespace
+
+//===---------------------------------------------------------------------===//
+// Methods and patterns that fuse reshape ops with elementwise operations by
+// expanding the dimensionality of the elementwise operations.
+//===---------------------------------------------------------------------===//
+
/// Conditions for folding a generic operation with a reshape op by expanding
/// the iteration space dimensionality for tensor operations. These are
/// preconditions assumed by `foldReshapeByDimExpansion` which implements the
@@ -612,9 +819,9 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
/// Note that this could be extended to handle dynamic case, but the
/// implementation below uses `affine.apply` which seems to have issues when the
/// shapes are not static.
-LogicalResult isGenericOpExpandable(GenericOp genericOp,
- const ExpansionInfo &expansionInfo,
- PatternRewriter &rewriter) {
+static LogicalResult isGenericOpExpandable(GenericOp genericOp,
+ const ExpansionInfo &expansionInfo,
+ PatternRewriter &rewriter) {
if (!genericOp.hasIndexSemantics())
return success();
for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
@@ -863,88 +1070,85 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
namespace {
-/// Pattern to fold tensor_expand_shape op with its consumer by using the source
-/// of the reshape op as the operand in the consumer (instead of the result of
-/// the tensor_collapse_shape). The corresponding index map in the consumer
-/// needs to be modified to linearize the folded dimension.
-///
-/// For example,
-///
-/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-/// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
-/// tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
-/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
-/// ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
-/// -> tensor<?x?x4x?xf32>
-///
-/// can be folded into
-///
-/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
-/// ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
-/// -> tensor<?x?x4x?xf32>
-template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
-struct FoldProducerReshapeOpByLinearization
+/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
+/// when the reshape op is collapsing dimensions. The dimensionality of the loop
+/// in the consumer is expanded.
+class FoldWithProducerReshapeOpByExpansion
: public OpRewritePattern<GenericOp> {
- using OpRewritePattern<GenericOp>::OpRewritePattern;
+public:
+ FoldWithProducerReshapeOpByExpansion(
+ MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<GenericOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- if (!genericOp.hasTensorSemantics())
- return failure();
- SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
- for (const auto &en : llvm::enumerate(inputOperands)) {
- auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
+ for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
+ tensor::CollapseShapeOp reshapeOp =
+ opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
if (!reshapeOp)
continue;
-
- if (!isTensorReshapeOpFoldableByLinearization(
- reshapeOp, genericOp.getTiedIndexingMap(en.value()),
- /*asProducer =*/true) ||
- (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
+ // Fold only if
+ // - The tensor reshape op is folding.
+ // - All constraints of fusing with reshape by expansion are met.
+ if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
+ (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
continue;
- // Compute the fused operands list,
- SmallVector<Value> fusedOperands = genericOp.getInputOperands();
- fusedOperands[en.index()] = reshapeOp.src();
- SmallVector<Value> outputOperands = genericOp.getOutputOperands();
- llvm::append_range(fusedOperands, outputOperands);
-
- // Compute indexing_maps for the fused operation. The indexing_maps for
- // the operands of the consumers that aren't fused are the same.
- SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
-
- // Compute the indexing map to use for the result of the producer.
- AffineMap modifiedMap =
- linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
- // The modified map cannot have symbols.
- if (modifiedMap.getNumSymbols())
+ Optional<SmallVector<Value>> replacementValues =
+ fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
+ if (!replacementValues)
return failure();
- for (AffineExpr expr : modifiedMap.getResults()) {
- if (!expr.isPureAffine())
- return failure();
- }
- fusedIndexMaps[en.index()] = modifiedMap;
-
- // Further check that the resulting index maps can be fused and
- // inverted. Without this the resultant op is not legal.
- if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
- return rewriter.notifyMatchFailure(
- genericOp, "fused op loop bound computation failed");
- }
-
- rewriter.startRootUpdate(genericOp);
- genericOp->setOperands(fusedOperands);
- genericOp.indexing_mapsAttr(
- rewriter.getAffineMapArrayAttr(fusedIndexMaps));
- rewriter.finalizeRootUpdate(genericOp);
+ rewriter.replaceOp(genericOp, replacementValues.getValue());
return success();
}
return failure();
}
+
+private:
+ ControlElementwiseOpsFusionFn controlFoldingReshapes;
};
+/// Pattern to fold a tensor_expand_shape op with its producer generic op
+/// by expanding the dimensionality of the loop in the producer op.
+struct FoldReshapeWithGenericOpByExpansion
+ : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+ FoldReshapeWithGenericOpByExpansion(
+ MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
+ PatternRewriter &rewriter) const override {
+ // Fold only if all constraints of fusing with reshape by expansion are met.
+ GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
+ if (!producer || producer.getNumOutputs() != 1 ||
+ !isFusableWithReshapeByDimExpansion(producer,
+ producer.getOutputOperand(0)) ||
+ !controlFoldingReshapes(producer->getResult(0),
+ reshapeOp->getOpOperand(0)))
+ return failure();
+ Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
+ producer, reshapeOp, producer.getOutputOperand(0), rewriter);
+ if (!replacementValues)
+ return failure();
+ rewriter.replaceOp(reshapeOp, replacementValues.getValue());
+ return success();
+ }
+
+private:
+ ControlElementwiseOpsFusionFn controlFoldingReshapes;
+};
+} // namespace
+
+//===---------------------------------------------------------------------===//
+// Methods and patterns to convert tensor.expand_shape -> linalg.generic
+// into linalg.generic -> tensor.expand_shape, i.e. push the reshape down.
+//===---------------------------------------------------------------------===//
+
static SmallVector<ReassociationIndices>
getReassociationIndices(ArrayRef<AffineMap> maps) {
SmallVector<ReassociationIndices> reassociation;
@@ -959,6 +1163,7 @@ getReassociationIndices(ArrayRef<AffineMap> maps) {
return reassociation;
}
+namespace {
/// Pattern to move rank reducing reshape after an elementwise linalg generic
/// op. This is useful to expose more fusion opportunities between named ops and
/// generic ops. This can only be done if there is no broadcast or permutation
@@ -1100,142 +1305,13 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
return success();
}
};
+} // namespace
-/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
-/// when the reshape op is collapsing dimensions. The dimensionality of the loop
-/// in the consumer is expanded.
-class FoldWithProducerReshapeOpByExpansion
- : public OpRewritePattern<GenericOp> {
-public:
- FoldWithProducerReshapeOpByExpansion(
- MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
- PatternBenefit benefit = 1)
- : OpRewritePattern<GenericOp>(context, benefit),
- controlFoldingReshapes(std::move(foldReshapes)) {}
-
- LogicalResult matchAndRewrite(GenericOp genericOp,
- PatternRewriter &rewriter) const override {
- for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
- tensor::CollapseShapeOp reshapeOp =
- opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
- if (!reshapeOp)
- continue;
- // Fold only if
- // - The tensor reshape op is folding.
- // - All constraints of fusing with reshape by expansion are met.
- if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
- (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
- continue;
-
- Optional<SmallVector<Value>> replacementValues =
- fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
- if (!replacementValues)
- return failure();
- rewriter.replaceOp(genericOp, replacementValues.getValue());
- return success();
- }
- return failure();
- }
-
-private:
- ControlElementwiseOpsFusionFn controlFoldingReshapes;
-};
-
-/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
-/// producer. The corresponding index map in the consumer needs to be modified
-/// to linearize the folded dimension.
-template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
-struct FoldConsumerReshapeOpByLinearization
- : public OpRewritePattern<TensorReshapeOp> {
- using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
- PatternRewriter &rewriter) const override {
- GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
- if (!producer || !producer.hasTensorSemantics() ||
- producer.getNumOutputs() != 1 ||
- !isTensorReshapeOpFoldableByLinearization(
- reshapeOp,
- producer.getTiedIndexingMap(producer.getOutputOperand(0)),
- /*asProducer =*/false) ||
- (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
- return failure();
- // The indexing_maps for the operands of the fused operation are same as
- // those for the operands of the producer.
- SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
-
- // Compute the indexing map to use for the operand of the producer.
- AffineMap modifiedMap = linearizeCollapsedDims(
- producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
- for (AffineExpr expr : modifiedMap.getResults()) {
- if (!expr.isPureAffine()) {
- return rewriter.notifyMatchFailure(
- producer, "fused op indexing map is not affine");
- }
- }
- fusedIndexMaps.back() = modifiedMap;
-
- // Further check that the resulting index maps can be fused and
- // inverted. Without this the resultant op is not legal.
- if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
- return rewriter.notifyMatchFailure(
- producer, "fused op loop bound computation failed");
- }
-
- Location loc = producer.getLoc();
- SmallVector<Value> inputOperands = producer.getInputOperands();
- Value output = rewriter.create<TensorReshapeOp>(
- loc, producer.getOutputOperand(0)->get(),
- reshapeOp.getReassociationExprs());
- auto fusedOp = rewriter.create<GenericOp>(
- loc, reshapeOp.getResultType(),
- /*inputs=*/inputOperands,
- // TODO: handle outputs.
- /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
- producer.iterator_types(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr);
- auto &fusedRegion = fusedOp->getRegion(0);
- rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
- fusedRegion.begin());
- rewriter.replaceOp(reshapeOp, fusedOp->getResults());
- return success();
- }
-};
-
-/// Pattern to fold a tensor_expand_shape op with its producer generic op
-/// by expanding the dimensionality of the loop in the producer op.
-struct FoldReshapeWithGenericOpByExpansion
- : public OpRewritePattern<tensor::ExpandShapeOp> {
-
- FoldReshapeWithGenericOpByExpansion(
- MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
- PatternBenefit benefit = 1)
- : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
- controlFoldingReshapes(std::move(foldReshapes)) {}
-
- LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
- PatternRewriter &rewriter) const override {
- // Fold only if all constraints of fusing with reshape by expansion are met.
- GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
- if (!producer || producer.getNumOutputs() != 1 ||
- !isFusableWithReshapeByDimExpansion(producer,
- producer.getOutputOperand(0)) ||
- !controlFoldingReshapes(producer->getResult(0),
- reshapeOp->getOpOperand(0)))
- return failure();
- Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
- producer, reshapeOp, producer.getOutputOperand(0), rewriter);
- if (!replacementValues)
- return failure();
- rewriter.replaceOp(reshapeOp, replacementValues.getValue());
- return success();
- }
-
-private:
- ControlElementwiseOpsFusionFn controlFoldingReshapes;
-};
+//===---------------------------------------------------------------------===//
+// Methods and patterns that fuse constants with linalg.generic operations.
+//===---------------------------------------------------------------------===//
+namespace {
/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
/// handle cases where the constant is not single-valued.
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
@@ -1624,98 +1700,11 @@ struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
} // namespace
-static Optional<SmallVector<Value>>
-fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
- GenericOp producer,
- const ControlElementwiseOpsFusionFn &controlFn) {
- if (producer->getNumResults() != 1)
- return llvm::None;
-
- return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
- rewriter);
-}
-
-bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
- OpOperand &consumer) {
- if (auto producerCollapseOp =
- dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
- return !isUnitDimExpansionOnly(producerCollapseOp);
- }
- if (auto consumerExpandOp =
- dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
- return !isUnitDimExpansionOnly(consumerExpandOp);
- }
- return true;
-}
+//===---------------------------------------------------------------------===//
+// Miscellaneous patterns that help fusion.
+//===---------------------------------------------------------------------===//
namespace {
-/// Patterns to fuse a generic op, with the producer of its operands.
-class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
-public:
- FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
- PatternBenefit benefit = 1)
- : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
-
- LogicalResult matchAndRewrite(GenericOp genericOp,
- PatternRewriter &rewriter) const override {
- // Find the first operand that is defined by another generic op on tensors.
- for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
- auto producer =
- dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
- if (!producer || !producer.hasTensorSemantics())
- continue;
- Optional<SmallVector<Value>> fusedOpResults =
- fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
- if (fusedOpResults) {
- rewriter.replaceOp(genericOp, *fusedOpResults);
- return success();
- }
- }
- return failure();
- }
-
-private:
- ControlElementwiseOpsFusionFn controlFn;
-};
-
-/// Pass that fuses generic ops on tensors. Used only for testing.
-struct LinalgElementwiseOpFusionPass
- : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
- void runOnOperation() override {
- Operation *op = getOperation();
- RewritePatternSet patterns(op->getContext());
- ControlElementwiseOpsFusionFn allowFoldingFn =
- [](const OpResult &producer, const OpOperand &consumer) {
- return true;
- };
- populateElementwiseOpsFusionPatterns(
- patterns,
- LinalgElementwiseFusionOptions().setControlFoldingReshapes(
- allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
-
- // Use TopDownTraversal for compile time reasons
- GreedyRewriteConfig grc;
- grc.useTopDownTraversal = true;
- (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
- grc);
- }
-};
-
-/// Pass to test folding of reshape ops with generic ops by linearization.
-struct FoldReshapeOpsByLinearizationPass
- : public LinalgFoldReshapeOpsByLinearizationBase<
- FoldReshapeOpsByLinearizationPass> {
- void runOnOperation() override {
- Operation *op = getOperation();
- RewritePatternSet patterns(op->getContext());
- populateFoldReshapeOpsByLinearizationPatterns(patterns);
- if (allowFoldingUnitDimReshapes) {
- populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
- }
- (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
- }
-};
-
/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
/// the value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for all
@@ -1761,9 +1750,12 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
return success();
}
};
-
} // namespace
+//===---------------------------------------------------------------------===//
+// Methods that add patterns described in this file to a pattern list.
+//===---------------------------------------------------------------------===//
+
void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns) {
patterns
@@ -1815,6 +1807,65 @@ void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
patterns.add<PushExpandingReshape>(context);
}
+//===---------------------------------------------------------------------===//
+// Passes
+//===---------------------------------------------------------------------===//
+
+bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
+ OpOperand &consumer) {
+ if (auto producerCollapseOp =
+ dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
+ return !isUnitDimExpansionOnly(producerCollapseOp);
+ }
+ if (auto consumerExpandOp =
+ dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
+ return !isUnitDimExpansionOnly(consumerExpandOp);
+ }
+ return true;
+}
+
+namespace {
+
+/// Pass that fuses generic ops on tensors. Used only for testing.
+struct LinalgElementwiseOpFusionPass
+ : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ RewritePatternSet patterns(op->getContext());
+ ControlElementwiseOpsFusionFn allowFoldingFn =
+ [](const OpResult &producer, const OpOperand &consumer) {
+ return true;
+ };
+ populateElementwiseOpsFusionPatterns(
+ patterns,
+ LinalgElementwiseFusionOptions().setControlFoldingReshapes(
+ allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
+
+ // Use TopDownTraversal for compile time reasons
+ GreedyRewriteConfig grc;
+ grc.useTopDownTraversal = true;
+ (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
+ grc);
+ }
+};
+
+/// Pass to test folding of reshape ops with generic ops by linearization.
+struct FoldReshapeOpsByLinearizationPass
+ : public LinalgFoldReshapeOpsByLinearizationBase<
+ FoldReshapeOpsByLinearizationPass> {
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ RewritePatternSet patterns(op->getContext());
+ populateFoldReshapeOpsByLinearizationPatterns(patterns);
+ if (allowFoldingUnitDimReshapes) {
+ populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
+ }
+ (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
+ }
+};
+
+} // namespace
+
std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
return std::make_unique<LinalgElementwiseOpFusionPass>();
}
More information about the Mlir-commits
mailing list