[Mlir-commits] [mlir] [mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp (PR #68526)
Aviad Cohen
llvmlistbot at llvm.org
Sun Oct 22 08:07:38 PDT 2023
https://github.com/AviadCo updated https://github.com/llvm/llvm-project/pull/68526
>From 41c5c80b69b730d4f20120d62780a2849de99d99 Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Sun, 8 Oct 2023 16:02:32 +0300
Subject: [PATCH] [mlir][linalg] Enable CollapseLinalgDimensions to collapse
linalg::CopyOp
---
.../Dialect/Linalg/Transforms/Transforms.h | 18 +-
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 186 ++++++++++--------
mlir/test/Dialect/Linalg/collapse-dim.mlir | 37 ++++
.../Linalg/TestLinalgElementwiseFusion.cpp | 2 +-
4 files changed, 150 insertions(+), 93 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 3597209d7f90c25..fbe2923c710aabb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1047,16 +1047,18 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
ArrayRef<ReassociationIndices> dimSequences);
-/// Collapses dimensions of linalg.generic operation. A precondition to
-/// calling this method is that for each list in `foldedIterationDim`, the
+/// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
+/// to calling this method is that for each list in `foldedIterationDims`, the
/// sequence of dimensions is contiguous in domains of all `indexing_maps` of
-/// the `genericOp`. This can be checked using `areDimSequencePreserved` method.
+/// the `op`. This can be checked using `areDimSequencesPreserved` method.
/// When valid, the method also collapses the operands of the op. Returns
-/// replacement values of the results of the original `genericOp` by inserting
+/// replacement values of the results of the original `op` by inserting
/// reshapes to get back values of compatible types.
-FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
- GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
- RewriterBase &rewriter);
+///
+/// `LinalgType` must be `GenericOp` or `CopyOp`.
+template <typename LinalgType>
+FailureOr<SmallVector<Value>>
+collapseOpIterationDims(LinalgType op,
+                        ArrayRef<ReassociationIndices> foldedIterationDims,
+                        RewriterBase &rewriter);
struct LowerPackResult {
tensor::PadOp padOp;
@@ -1515,7 +1517,7 @@ void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns);
/// to return an array of `ReassociationIndices` representing dimensions that
/// should be merged.
using GetCollapsableDimensionsFn =
- std::function<SmallVector<ReassociationIndices>(linalg::GenericOp)>;
+ std::function<SmallVector<ReassociationIndices>(linalg::LinalgOp)>;
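-/// Pattern to collapse dimensions in a linalg.generic op. This will collapse
-/// tensor operands when needed and expand back the result tensors.
+/// Patterns to collapse dimensions in linalg.generic and linalg.copy ops. These
+/// collapse tensor operands when needed and expand back the result tensors.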
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 6f4b0ff60ca97c6..35d7d86fd8f1d7a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1373,16 +1373,17 @@ getOperandReassociation(AffineMap indexingMap,
}
/// Get the new value to use for a given `OpOperand` in the collapsed operation.
-static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
+static Value getCollapsedOpOperand(Location loc, LinalgOp op,
OpOperand *opOperand,
const CollapsingInfo &collapsingInfo,
OpBuilder &builder) {
- AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
+ AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
SmallVector<ReassociationIndices> operandReassociation =
getOperandReassociation(indexingMap, collapsingInfo);
- // If the number of entries in the reassocation for the operand is same as the
- // number of results of the indexing map, then nothing to do for this operand.
+ // If the number of entries in the reassociation for the operand is same as
+ // the number of results of the indexing map, then nothing to do for this
+ // operand.
Value operand = opOperand->get();
if (operandReassociation.size() == indexingMap.getNumResults())
return operand;
@@ -1439,20 +1440,80 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
}
}
+/// Create the collapsed counterpart of `op` as described by `collapsingInfo`.
+///
+/// Every DPS input and init of `op` is replaced by its collapsed value (see
+/// `getCollapsedOpOperand`). With tensor semantics the collapsed init types
+/// become the result types; with buffer semantics the new op has no results.
+///  - `linalg.copy` is rebuilt as a `linalg.copy` of the collapsed operands.
+///  - `linalg.generic` is rebuilt with collapsed indexing maps and iterator
+///    types, and the body of `op` is moved into it with `mergeBlocks`, so `op`
+///    is left without its body and is expected to be replaced by the caller.
+template <typename LinalgType>
+static Operation *createCollapsedOp(LinalgType op,
+                                    const CollapsingInfo &collapsingInfo,
+                                    RewriterBase &rewriter) {
+  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
+                "unsupported linalg op type to create");
+  Location loc = op->getLoc();
+
+  // Collapse the input operands.
+  SmallVector<Value> inputOperands =
+      llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
+        return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
+                                     rewriter);
+      });
+
+  // Collapse the init operands and derive the result types from them.
+  SmallVector<Type> resultTypes;
+  SmallVector<Value> outputOperands;
+  resultTypes.reserve(op.getNumDpsInits());
+  outputOperands.reserve(op.getNumDpsInits());
+  for (OpOperand &output : op.getDpsInitsMutable()) {
+    Value newOutput =
+        getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
+    outputOperands.push_back(newOutput);
+    // If the op has "buffer semantics", then the init operands are ranked
+    // memrefs and the op has no results.
+    if (!op.hasBufferSemantics())
+      resultTypes.push_back(newOutput.getType());
+  }
+
+  // The op type is a template parameter, so dispatch at compile time: only the
+  // construction path matching `LinalgType` is instantiated.
+  if constexpr (llvm::is_one_of<LinalgType, CopyOp>::value) {
+    // linalg.copy is rebuilt from its single collapsed input and init.
+    return rewriter.create<linalg::CopyOp>(loc, inputOperands[0],
+                                           outputOperands[0]);
+  } else {
+    // Collapse the iterator types of the op.
+    SmallVector<utils::IteratorType> iteratorTypes =
+        getCollapsedOpIteratorTypes(op.getIteratorTypesArray(),
+                                    collapsingInfo);
+
+    // Collapse the indexing maps of the op.
+    auto indexingMaps =
+        llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
+          return getCollapsedOpIndexingMap(map, collapsingInfo);
+        });
+
+    // Create the op with an empty body, then move the original body into it.
+    Operation *collapsedOp = rewriter.create<linalg::GenericOp>(
+        loc, resultTypes, inputOperands, outputOperands, indexingMaps,
+        iteratorTypes, [](OpBuilder &, Location, ValueRange) {});
+    Block *origOpBlock = &op->getRegion(0).front();
+    Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
+    rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
+                         collapsedOpBlock->getArguments());
+    return collapsedOp;
+  }
+}
+
/// Implementation of fusion with reshape operation by collapsing dimensions.
-FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
- GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
+template <typename LinalgType>
+FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
+ LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter) {
+ static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
+ "unsupported linalg op type to collapse");
+
// Bail on trivial no-op cases.
- if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
+ if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
return foldedDims.size() <= 1;
}))
return failure();
- bool hasBufferSemantics = genericOp.hasBufferSemantics();
+ bool hasBufferSemantics = op.hasBufferSemantics();
if (hasBufferSemantics &&
- !llvm::all_of(genericOp->getOperands(), [&](Value operand) -> bool {
+ !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
if (!memRefToCollapse)
return true;
@@ -1460,20 +1521,19 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
return memref::CollapseShapeOp::isGuaranteedCollapsible(
memRefToCollapse, foldedIterationDims);
}))
- return rewriter.notifyMatchFailure(genericOp,
+ return rewriter.notifyMatchFailure(op,
"memref is not guaranteed collapsible");
CollapsingInfo collapsingInfo;
- if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
- foldedIterationDims))) {
+ if (failed(
+ collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
return rewriter.notifyMatchFailure(
- genericOp, "illegal to collapse specified dimensions");
+ op, "illegal to collapse specified dimensions");
}
// Bail on non-canonical ranges.
SmallVector<Range> loopRanges =
- cast<LinalgOp>(genericOp.getOperation())
- .createLoopRanges(rewriter, genericOp.getLoc());
+ cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
return cast<IntegerAttr>(attr).getInt() == value;
@@ -1486,78 +1546,36 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
opFoldIsConstantValue(range.stride, 1);
})) {
return rewriter.notifyMatchFailure(
- genericOp,
- "expected all loop ranges to have zero start and unit stride");
+ op, "expected all loop ranges to have zero start and unit stride");
}
- // Get the iterator types for the operand.
- SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(
- genericOp.getIteratorTypesArray(), collapsingInfo);
-
- // Get the indexing maps.
- auto indexingMaps = llvm::to_vector(
- llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) {
- return getCollapsedOpIndexingMap(map, collapsingInfo);
- }));
-
- Location loc = genericOp->getLoc();
-
- // Get the input operands.
- auto inputOperands = llvm::to_vector(llvm::map_range(
- genericOp.getDpsInputOperands(), [&](OpOperand *opOperand) {
- return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
- rewriter);
- }));
-
- // Get the output operands and result types.
- SmallVector<Type> resultTypes;
- SmallVector<Value> outputOperands;
- resultTypes.reserve(genericOp.getNumDpsInits());
- outputOperands.reserve(genericOp.getNumDpsInits());
- for (OpOperand &output : genericOp.getDpsInitsMutable()) {
- Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
- collapsingInfo, rewriter);
- outputOperands.push_back(newOutput);
- // If the op has "buffer semantics", then the init operands are ranked
- // memrefs and the op has no results.
- if (!hasBufferSemantics)
- resultTypes.push_back(newOutput.getType());
- }
-
- // Create the generic op.
- auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
- loc, resultTypes, inputOperands, outputOperands, indexingMaps,
- iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
- Block *origOpBlock = &genericOp->getRegion(0).front();
- Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
- rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
- collapsedOpBlock->getArguments());
+ LinalgType collapsedOp = cast<LinalgType>(
+ createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter));
- if (collapsedGenericOp.hasIndexSemantics()) {
+ Location loc = op->getLoc();
+ if (collapsedOp.hasIndexSemantics()) {
// Collect the loop range of the generic op.
OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(collapsedGenericOp);
+ rewriter.setInsertionPoint(collapsedOp);
SmallVector<Value> loopBound =
- llvm::to_vector(llvm::map_range(loopRanges, [&](Range range) {
+ llvm::map_to_vector(loopRanges, [&](Range range) {
return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
- }));
- generateCollapsedIndexingRegion(loc,
- &collapsedGenericOp->getRegion(0).front(),
+ });
+ generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
collapsingInfo, loopBound, rewriter);
}
// Insert expanding reshape for the result to get back the original result
// type.
SmallVector<Value> results;
- for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
- Value collapsedOpResult =
- collapsedGenericOp->getResult(originalResult.index());
+ for (const auto &originalResult : llvm::enumerate(op->getResults())) {
+ Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
auto originalResultType =
cast<ShapedType>(originalResult.value().getType());
auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
AffineMap indexingMap =
- genericOp.getIndexingMapMatchingResult(originalResult.value());
+ op.getIndexingMapMatchingResult(originalResult.value());
SmallVector<ReassociationIndices> reassociation =
getOperandReassociation(indexingMap, collapsingInfo);
if (isa<MemRefType>(collapsedOpResult.getType())) {
@@ -1606,8 +1624,8 @@ class FoldWithProducerReshapeOpByCollapsing
}
std::optional<SmallVector<Value>> replacements =
- collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
- rewriter);
+ collapseOpIterationDims<linalg::GenericOp>(
+ genericOp, collapsableIterationDims, rewriter);
if (!replacements) {
return rewriter.notifyMatchFailure(
genericOp, "failed to do the fusion by collapsing transformation");
@@ -1624,36 +1642,36 @@ class FoldWithProducerReshapeOpByCollapsing
};
/// Pattern to collapse dimensions.
-class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> {
+template <typename LinalgType>
+class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
public:
CollapseLinalgDimensions(MLIRContext *context,
GetCollapsableDimensionsFn collapseDimensions,
PatternBenefit benefit = 1)
- : OpRewritePattern<GenericOp>(context, benefit),
+ : OpRewritePattern<LinalgType>(context, benefit),
controlCollapseDimension(std::move(collapseDimensions)) {}
- LogicalResult matchAndRewrite(GenericOp genericOp,
+ LogicalResult matchAndRewrite(LinalgType op,
PatternRewriter &rewriter) const override {
SmallVector<ReassociationIndices> collapsableIterationDims =
- controlCollapseDimension(genericOp);
+ controlCollapseDimension(op);
if (collapsableIterationDims.empty())
return failure();
// Check if the specified list of dimensions to collapse is a valid list.
- if (!areDimSequencesPreserved(genericOp.getIndexingMapsArray(),
+ if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
collapsableIterationDims)) {
return rewriter.notifyMatchFailure(
- genericOp, "specified dimensions cannot be collapsed");
+ op, "specified dimensions cannot be collapsed");
}
std::optional<SmallVector<Value>> replacements =
- collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
- rewriter);
+ collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
+ rewriter);
if (!replacements) {
- return rewriter.notifyMatchFailure(genericOp,
- "failed to collapse dimensions");
+ return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
}
- rewriter.replaceOp(genericOp, *replacements);
+ rewriter.replaceOp(op, *replacements);
return success();
}
@@ -1884,8 +1902,8 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
+// Explicitly instantiate the supported specializations of
+// `collapseOpIterationDims`. The template is declared in Transforms.h but only
+// defined in this file, so callers in other translation units rely on these
+// instantiations being emitted here.
+template FailureOr<SmallVector<Value>>
+mlir::linalg::collapseOpIterationDims<linalg::GenericOp>(
+    linalg::GenericOp, ArrayRef<ReassociationIndices>, RewriterBase &);
+template FailureOr<SmallVector<Value>>
+mlir::linalg::collapseOpIterationDims<linalg::CopyOp>(
+    linalg::CopyOp, ArrayRef<ReassociationIndices>, RewriterBase &);
+
void mlir::linalg::populateCollapseDimensions(
RewritePatternSet &patterns,
const GetCollapsableDimensionsFn &controlCollapseDimensions) {
- patterns.add<CollapseLinalgDimensions>(patterns.getContext(),
- controlCollapseDimensions);
+  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
+               CollapseLinalgDimensions<linalg::CopyOp>>(
+      patterns.getContext(), controlCollapseDimensions);
}
//===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 106154ba3a553bd..547320f53387477 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -116,3 +116,40 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
}
return %alloc : memref<2x6x24x48xi32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @linalg_copy(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
+// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
+// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
+// CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
+// CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
+// CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
+// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
+// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK: return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK: }
+
+func.func @linalg_copy(%src: tensor<1x2x3x4x5xf32, 1>,
+                       %dst: tensor<1x2x3x4x5xf32, 3>)
+    -> tensor<1x2x3x4x5xf32, 3> {
+  // 5-D copy between tensors that differ only in their encoding attribute.
+  %copied = linalg.copy
+      ins(%src : tensor<1x2x3x4x5xf32, 1>)
+      outs(%dst : tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3>
+  return %copied : tensor<1x2x3x4x5xf32, 3>
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @memref_linalg_copy(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x24x32x8xf32, 1>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<1x24x32x8xf32, 1>) {
+// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1>
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1>
+// CHECK: linalg.copy ins(%[[VAL_2]] : memref<1x24x256xf32, 1>) outs(%[[VAL_3]] : memref<1x24x256xf32, 1>)
+// CHECK: return
+// CHECK: }
+
+func.func private @memref_linalg_copy(%src: memref<1x24x32x8xf32, 1>,
+                                      %dst: memref<1x24x32x8xf32, 1>) {
+  // Buffer-semantics copy: the op writes into %dst and produces no results.
+  linalg.copy
+      ins(%src : memref<1x24x32x8xf32, 1>)
+      outs(%dst : memref<1x24x32x8xf32, 1>)
+  return
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index e41481a9e51364e..7f68f4aec3a10c3 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -258,7 +258,7 @@ struct TestLinalgElementwiseFusion
SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
collapseDimensions.end());
linalg::GetCollapsableDimensionsFn collapseFn =
- [&dims](linalg::GenericOp op) {
+ [&dims](linalg::LinalgOp op) {
SmallVector<ReassociationIndices> reassociations;
reassociations.emplace_back(dims);
return reassociations;
More information about the Mlir-commits
mailing list