[Mlir-commits] [mlir] [mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp (PR #68526)
Aviad Cohen
llvmlistbot at llvm.org
Wed Oct 11 01:28:36 PDT 2023
================
@@ -1467,80 +1490,97 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
opFoldIsConstantValue(range.stride, 1);
})) {
return rewriter.notifyMatchFailure(
- genericOp,
- "expected all loop ranges to have zero start and unit stride");
+ op, "expected all loop ranges to have zero start and unit stride");
}
// Get the iterator types for the operand.
- SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(
- genericOp.getIteratorTypesArray(), collapsingInfo);
+ SmallVector<utils::IteratorType> iteratorTypes =
+ getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
// Get the indexing maps.
auto indexingMaps = llvm::to_vector(
- llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) {
+ llvm::map_range(op.getIndexingMapsArray(), [&](AffineMap map) {
return getCollapsedOpIndexingMap(map, collapsingInfo);
}));
- Location loc = genericOp->getLoc();
+ Location loc = op->getLoc();
// Get the input operands.
- auto inputOperands = llvm::to_vector(llvm::map_range(
- genericOp.getDpsInputOperands(), [&](OpOperand *opOperand) {
- return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
+ auto inputOperands = llvm::to_vector(
+ llvm::map_range(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
+ return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
rewriter);
}));
// Get the output operands and result types.
SmallVector<Type> resultTypes;
SmallVector<Value> outputOperands;
- resultTypes.reserve(genericOp.getNumDpsInits());
- outputOperands.reserve(genericOp.getNumDpsInits());
- for (OpOperand &output : genericOp.getDpsInitsMutable()) {
- Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
- collapsingInfo, rewriter);
+ resultTypes.reserve(op.getNumDpsInits());
+ outputOperands.reserve(op.getNumDpsInits());
+ for (OpOperand &output : op.getDpsInitsMutable()) {
+ Value newOutput =
+ getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
outputOperands.push_back(newOutput);
- resultTypes.push_back(newOutput.getType());
+ // If the op has "buffer semantics", then the init operands are ranked
+ // memrefs and the op has no results.
+ if (!hasBufferSemantics)
+ resultTypes.push_back(newOutput.getType());
}
// Create the generic op.
- auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
- loc, resultTypes, inputOperands, outputOperands, indexingMaps,
- iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
- Block *origOpBlock = &genericOp->getRegion(0).front();
- Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
- rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
- collapsedOpBlock->getArguments());
-
- if (collapsedGenericOp.hasIndexSemantics()) {
+ Operation *collapsedOp;
----------------
AviadCo wrote:
I agree that this function is long; I extracted the relevant code into a separate c'tor (op-construction) function.
https://github.com/llvm/llvm-project/pull/68526
More information about the Mlir-commits mailing list