[Mlir-commits] [mlir] [mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp (PR #68526)
Aviad Cohen
llvmlistbot at llvm.org
Sun Oct 8 06:30:27 PDT 2023
https://github.com/AviadCo created https://github.com/llvm/llvm-project/pull/68526
- [mlir][linalg] Enable CollapseLinalgDimensions to collapse memref based operations
- [mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp
>From 5e59fd1725a54a3ab5ab6f8d1cb46f3bb06fb8ce Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Fri, 6 Oct 2023 15:07:37 +0300
Subject: [PATCH 1/2] [mlir][linalg] Enable CollapseLinalgDimensions to
collapse memref based operations
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 42 ++++++++++++++---
mlir/test/Dialect/Linalg/collapse-dim.mlir | 46 +++++++++++++++++++
2 files changed, 81 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 069c613cc246d6a..6f4b0ff60ca97c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1388,9 +1388,15 @@ static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
return operand;
// Insert a reshape to collapse the dimensions.
- auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
- loc, operand, operandReassociation);
- return reshapeOp.getResult();
+ if (isa<MemRefType>(operand.getType())) {
+ return builder
+ .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
+ .getResult();
+ } else {
+ return builder
+ .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
+ .getResult();
+ }
}
/// Modify the `linalg.index` operations in the original generic op, to its
@@ -1444,6 +1450,19 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
}))
return failure();
+ bool hasBufferSemantics = genericOp.hasBufferSemantics();
+ if (hasBufferSemantics &&
+ !llvm::all_of(genericOp->getOperands(), [&](Value operand) -> bool {
+ MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
+ if (!memRefToCollapse)
+ return true;
+
+ return memref::CollapseShapeOp::isGuaranteedCollapsible(
+ memRefToCollapse, foldedIterationDims);
+ }))
+ return rewriter.notifyMatchFailure(genericOp,
+ "memref is not guaranteed collapsible");
+
CollapsingInfo collapsingInfo;
if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
foldedIterationDims))) {
@@ -1499,7 +1518,10 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
collapsingInfo, rewriter);
outputOperands.push_back(newOutput);
- resultTypes.push_back(newOutput.getType());
+ // If the op has "buffer semantics", then the init operands are ranked
+ // memrefs and the op has no results.
+ if (!hasBufferSemantics)
+ resultTypes.push_back(newOutput.getType());
}
// Create the generic op.
@@ -1538,9 +1560,15 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
genericOp.getIndexingMapMatchingResult(originalResult.value());
SmallVector<ReassociationIndices> reassociation =
getOperandReassociation(indexingMap, collapsingInfo);
- Value result = rewriter.create<tensor::ExpandShapeOp>(
- loc, originalResultType, collapsedOpResult, reassociation);
- results.push_back(result);
+ if (isa<MemRefType>(collapsedOpResult.getType())) {
+ Value result = rewriter.create<memref::ExpandShapeOp>(
+ loc, originalResultType, collapsedOpResult, reassociation);
+ results.push_back(result);
+ } else {
+ Value result = rewriter.create<tensor::ExpandShapeOp>(
+ loc, originalResultType, collapsedOpResult, reassociation);
+ results.push_back(result);
+ }
} else {
results.push_back(collapsedOpResult);
}
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 6737a6e15da5afe..106154ba3a553bd 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -70,3 +70,49 @@ func.func @uncollapsable(%arg0 : tensor<41x3x1x57xf32>, %arg1 : tensor<3x1x57x41
// CHECK-LABEL: func @uncollapsable(
// CHECK: linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+// -----
+
+// CHECK-LABEL: func.func private @collapsable_memref(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x24x32x8xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<1x24x32x8xf32>) -> memref<1x24x32x8xf32> {
+// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK: %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK: %[[VAL_5:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_3]], %[[VAL_4]] : memref<1x24x256xf32>, memref<1x24x256xf32>) outs(%[[VAL_5]] : memref<1x24x256xf32>) {
+// CHECK: ^bb0(%[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
+// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32
+// CHECK: linalg.yield %[[VAL_9]] : f32
+// CHECK: }
+// CHECK: return %[[VAL_2]] : memref<1x24x32x8xf32>
+// CHECK: }
+
+func.func private @collapsable_memref(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>) -> (memref<1x24x32x8xf32>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
+ linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%alloc : memref<1x24x32x8xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.addf %in, %in_0 : f32
+ linalg.yield %0 : f32
+ }
+ return %alloc : memref<1x24x32x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @uncollapsable_strided_memref(
+// CHECK: linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: memref<2x6x24x48xi32>) -> (memref<2x6x24x48xi32>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x6x24x48xi32>
+ %subview = memref.subview %arg0[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+ %subview0 = memref.subview %arg1[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+ %subview1 = memref.subview %alloc[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+ linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview, %subview0 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>, memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) outs(%subview1 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) {
+ ^bb0(%in: i32, %in_0: i32, %out: i32):
+ %0 = arith.addi %in, %in_0 : i32
+ linalg.yield %0 : i32
+ }
+ return %alloc : memref<2x6x24x48xi32>
+}
>From 17609b475479f2ac955ac92c6f8292e13d312eb8 Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Sun, 8 Oct 2023 16:02:32 +0300
Subject: [PATCH 2/2] [mlir][linalg] Enable CollapseLinalgDimensions to
collapse linalg::CopyOp
---
.../Dialect/Linalg/Transforms/Transforms.h | 18 +--
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 130 ++++++++++--------
mlir/test/Dialect/Linalg/collapse-dim.mlir | 37 +++++
.../Linalg/TestLinalgElementwiseFusion.cpp | 2 +-
4 files changed, 120 insertions(+), 67 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 07a192f7b8606d3..0b0be116ce1c1d6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1047,16 +1047,18 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
ArrayRef<ReassociationIndices> dimSequences);
-/// Collapses dimensions of linalg.generic operation. A precondition to
-/// calling this method is that for each list in `foldedIterationDim`, the
+/// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
+/// to calling this method is that for each list in `foldedIterationDims`, the
/// sequence of dimensions is contiguous in domains of all `indexing_maps` of
-/// the `genericOp`. This can be checked using `areDimSequencePreserved` method.
+/// the `linalgOp`. This can be checked using `areDimSequencesPreserved` method.
/// When valid, the method also collapses the operands of the op. Returns
-/// replacement values of the results of the original `genericOp` by inserting
+/// replacement values of the results of the original `linalgOp` by inserting
/// reshapes to get back values of compatible types.
-FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
- GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
- RewriterBase &rewriter);
+template <typename LinalgType>
+FailureOr<SmallVector<Value>>
+collapseOpIterationDims(LinalgType op,
+ ArrayRef<ReassociationIndices> foldedIterationDims,
+ RewriterBase &rewriter);
struct LowerPackResult {
tensor::PadOp padOp;
@@ -1507,7 +1509,7 @@ void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns);
/// to return an array of `ReassociationIndices` representing dimensions that
/// should be merged.
using GetCollapsableDimensionsFn =
- std::function<SmallVector<ReassociationIndices>(linalg::GenericOp)>;
+ std::function<SmallVector<ReassociationIndices>(linalg::LinalgOp)>;
/// Pattern to collapse dimensions in a linalg.generic op. This will collapse
/// tensor operands when needed and expand back the result tensors.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 6f4b0ff60ca97c6..3e5f0ec24ffde99 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1373,16 +1373,17 @@ getOperandReassociation(AffineMap indexingMap,
}
/// Get the new value to use for a given `OpOperand` in the collapsed operation.
-static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
+static Value getCollapsedOpOperand(Location loc, LinalgOp op,
OpOperand *opOperand,
const CollapsingInfo &collapsingInfo,
OpBuilder &builder) {
- AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
+ AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
SmallVector<ReassociationIndices> operandReassociation =
getOperandReassociation(indexingMap, collapsingInfo);
- // If the number of entries in the reassocation for the operand is same as the
- // number of results of the indexing map, then nothing to do for this operand.
+ // If the number of entries in the reassociation for the operand is same as
+ // the number of results of the indexing map, then nothing to do for this
+ // operand.
Value operand = opOperand->get();
if (operandReassociation.size() == indexingMap.getNumResults())
return operand;
@@ -1440,19 +1441,23 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
}
/// Implementation of fusion with reshape operation by collapsing dimensions.
-FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
- GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
+template <typename LinalgType>
+FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
+ LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter) {
+ static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
+ "unsupported linalg op type to collapse");
+
// Bail on trivial no-op cases.
- if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
+ if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
return foldedDims.size() <= 1;
}))
return failure();
- bool hasBufferSemantics = genericOp.hasBufferSemantics();
+ bool hasBufferSemantics = op.hasBufferSemantics();
if (hasBufferSemantics &&
- !llvm::all_of(genericOp->getOperands(), [&](Value operand) -> bool {
+ !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
if (!memRefToCollapse)
return true;
@@ -1460,20 +1465,19 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
return memref::CollapseShapeOp::isGuaranteedCollapsible(
memRefToCollapse, foldedIterationDims);
}))
- return rewriter.notifyMatchFailure(genericOp,
+ return rewriter.notifyMatchFailure(op,
"memref is not guaranteed collapsible");
CollapsingInfo collapsingInfo;
- if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
- foldedIterationDims))) {
+ if (failed(
+ collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
return rewriter.notifyMatchFailure(
- genericOp, "illegal to collapse specified dimensions");
+ op, "illegal to collapse specified dimensions");
}
// Bail on non-canonical ranges.
SmallVector<Range> loopRanges =
- cast<LinalgOp>(genericOp.getOperation())
- .createLoopRanges(rewriter, genericOp.getLoc());
+ cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
return cast<IntegerAttr>(attr).getInt() == value;
@@ -1486,37 +1490,36 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
opFoldIsConstantValue(range.stride, 1);
})) {
return rewriter.notifyMatchFailure(
- genericOp,
- "expected all loop ranges to have zero start and unit stride");
+ op, "expected all loop ranges to have zero start and unit stride");
}
// Get the iterator types for the operand.
- SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(
- genericOp.getIteratorTypesArray(), collapsingInfo);
+ SmallVector<utils::IteratorType> iteratorTypes =
+ getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
// Get the indexing maps.
auto indexingMaps = llvm::to_vector(
- llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) {
+ llvm::map_range(op.getIndexingMapsArray(), [&](AffineMap map) {
return getCollapsedOpIndexingMap(map, collapsingInfo);
}));
- Location loc = genericOp->getLoc();
+ Location loc = op->getLoc();
// Get the input operands.
- auto inputOperands = llvm::to_vector(llvm::map_range(
- genericOp.getDpsInputOperands(), [&](OpOperand *opOperand) {
- return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
+ auto inputOperands = llvm::to_vector(
+ llvm::map_range(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
+ return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
rewriter);
}));
// Get the output operands and result types.
SmallVector<Type> resultTypes;
SmallVector<Value> outputOperands;
- resultTypes.reserve(genericOp.getNumDpsInits());
- outputOperands.reserve(genericOp.getNumDpsInits());
- for (OpOperand &output : genericOp.getDpsInitsMutable()) {
- Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
- collapsingInfo, rewriter);
+ resultTypes.reserve(op.getNumDpsInits());
+ outputOperands.reserve(op.getNumDpsInits());
+ for (OpOperand &output : op.getDpsInitsMutable()) {
+ Value newOutput =
+ getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
outputOperands.push_back(newOutput);
// If the op has "buffer semantics", then the init operands are ranked
// memrefs and the op has no results.
@@ -1525,39 +1528,48 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
}
// Create the generic op.
- auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
- loc, resultTypes, inputOperands, outputOperands, indexingMaps,
- iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
- Block *origOpBlock = &genericOp->getRegion(0).front();
- Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
- rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
- collapsedOpBlock->getArguments());
-
- if (collapsedGenericOp.hasIndexSemantics()) {
+ Operation *collapsedOp;
+ if (isa<linalg::GenericOp>(op)) {
+ collapsedOp = rewriter.create<linalg::GenericOp>(
+ loc, resultTypes, inputOperands, outputOperands, indexingMaps,
+ iteratorTypes,
+ [](OpBuilder &builder, Location loc, ValueRange args) {});
+ Block *origOpBlock = &op->getRegion(0).front();
+ Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
+ rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
+ collapsedOpBlock->getArguments());
+ } else {
+ assert(isa<linalg::CopyOp>(op));
+ collapsedOp = rewriter.create<linalg::CopyOp>(loc, inputOperands[0],
+ outputOperands[0]);
+ }
+ LinalgType collapsedLinalgOp = cast<LinalgType>(collapsedOp);
+
+ if (collapsedLinalgOp.hasIndexSemantics()) {
// Collect the loop range of the generic op.
OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(collapsedGenericOp);
+ rewriter.setInsertionPoint(collapsedLinalgOp);
SmallVector<Value> loopBound =
llvm::to_vector(llvm::map_range(loopRanges, [&](Range range) {
return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
}));
generateCollapsedIndexingRegion(loc,
- &collapsedGenericOp->getRegion(0).front(),
+ &collapsedLinalgOp->getRegion(0).front(),
collapsingInfo, loopBound, rewriter);
}
// Insert expanding reshape for the result to get back the original result
// type.
SmallVector<Value> results;
- for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
+ for (const auto &originalResult : llvm::enumerate(op->getResults())) {
Value collapsedOpResult =
- collapsedGenericOp->getResult(originalResult.index());
+ collapsedLinalgOp->getResult(originalResult.index());
auto originalResultType =
cast<ShapedType>(originalResult.value().getType());
auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
AffineMap indexingMap =
- genericOp.getIndexingMapMatchingResult(originalResult.value());
+ op.getIndexingMapMatchingResult(originalResult.value());
SmallVector<ReassociationIndices> reassociation =
getOperandReassociation(indexingMap, collapsingInfo);
if (isa<MemRefType>(collapsedOpResult.getType())) {
@@ -1606,8 +1618,8 @@ class FoldWithProducerReshapeOpByCollapsing
}
std::optional<SmallVector<Value>> replacements =
- collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
- rewriter);
+ collapseOpIterationDims<linalg::GenericOp>(
+ genericOp, collapsableIterationDims, rewriter);
if (!replacements) {
return rewriter.notifyMatchFailure(
genericOp, "failed to do the fusion by collapsing transformation");
@@ -1624,36 +1636,36 @@ class FoldWithProducerReshapeOpByCollapsing
};
/// Pattern to collapse dimensions.
-class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> {
+template <typename LinalgType>
+class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
public:
CollapseLinalgDimensions(MLIRContext *context,
GetCollapsableDimensionsFn collapseDimensions,
PatternBenefit benefit = 1)
- : OpRewritePattern<GenericOp>(context, benefit),
+ : OpRewritePattern<LinalgType>(context, benefit),
controlCollapseDimension(std::move(collapseDimensions)) {}
- LogicalResult matchAndRewrite(GenericOp genericOp,
+ LogicalResult matchAndRewrite(LinalgType op,
PatternRewriter &rewriter) const override {
SmallVector<ReassociationIndices> collapsableIterationDims =
- controlCollapseDimension(genericOp);
+ controlCollapseDimension(op);
if (collapsableIterationDims.empty())
return failure();
// Check if the specified list of dimensions to collapse is a valid list.
- if (!areDimSequencesPreserved(genericOp.getIndexingMapsArray(),
+ if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
collapsableIterationDims)) {
return rewriter.notifyMatchFailure(
- genericOp, "specified dimensions cannot be collapsed");
+ op, "specified dimensions cannot be collapsed");
}
std::optional<SmallVector<Value>> replacements =
- collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
- rewriter);
+ collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
+ rewriter);
if (!replacements) {
- return rewriter.notifyMatchFailure(genericOp,
- "failed to collapse dimensions");
+ return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
}
- rewriter.replaceOp(genericOp, *replacements);
+ rewriter.replaceOp(op, *replacements);
return success();
}
@@ -1884,8 +1896,10 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
void mlir::linalg::populateCollapseDimensions(
RewritePatternSet &patterns,
const GetCollapsableDimensionsFn &controlCollapseDimensions) {
- patterns.add<CollapseLinalgDimensions>(patterns.getContext(),
- controlCollapseDimensions);
+ patterns.add<CollapseLinalgDimensions<linalg::GenericOp>>(
+ patterns.getContext(), controlCollapseDimensions);
+ patterns.add<CollapseLinalgDimensions<linalg::CopyOp>>(
+ patterns.getContext(), controlCollapseDimensions);
}
//===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 106154ba3a553bd..547320f53387477 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -116,3 +116,40 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
}
return %alloc : memref<2x6x24x48xi32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @linalg_copy(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
+// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
+// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
+// CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
+// CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
+// CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
+// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
+// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK: return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK: }
+
+func.func @linalg_copy(
+ %arg0: tensor<1x2x3x4x5xf32, 1>, %arg1: tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3> {
+ %0 = linalg.copy ins(%arg0: tensor<1x2x3x4x5xf32, 1>) outs(%arg1: tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3>
+ return %0 : tensor<1x2x3x4x5xf32, 3>
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @memref_linalg_copy(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x24x32x8xf32, 1>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<1x24x32x8xf32, 1>) {
+// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1>
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1>
+// CHECK: linalg.copy ins(%[[VAL_2]] : memref<1x24x256xf32, 1>) outs(%[[VAL_3]] : memref<1x24x256xf32, 1>)
+// CHECK: return
+// CHECK: }
+
+func.func private @memref_linalg_copy(%arg0: memref<1x24x32x8xf32, 1>, %arg1: memref<1x24x32x8xf32, 1>) {
+ linalg.copy ins(%arg0: memref<1x24x32x8xf32, 1>) outs(%arg1: memref<1x24x32x8xf32, 1>)
+ return
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index e41481a9e51364e..7f68f4aec3a10c3 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -258,7 +258,7 @@ struct TestLinalgElementwiseFusion
SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
collapseDimensions.end());
linalg::GetCollapsableDimensionsFn collapseFn =
- [&dims](linalg::GenericOp op) {
+ [&dims](linalg::LinalgOp op) {
SmallVector<ReassociationIndices> reassociations;
reassociations.emplace_back(dims);
return reassociations;
More information about the Mlir-commits
mailing list