[Mlir-commits] [mlir] f84b908 - [mlir][linalg] Cleanup LinalgOp usage in fusion on tensors (NFC).
Tobias Gysi
llvmlistbot at llvm.org
Wed Jun 2 05:21:31 PDT 2021
Author: Tobias Gysi
Date: 2021-06-02T12:20:45Z
New Revision: f84b908f89af76002112acbf915ab0677b99c01c
URL: https://github.com/llvm/llvm-project/commit/f84b908f89af76002112acbf915ab0677b99c01c
DIFF: https://github.com/llvm/llvm-project/commit/f84b908f89af76002112acbf915ab0677b99c01c.diff
LOG: [mlir][linalg] Cleanup LinalgOp usage in fusion on tensors (NFC).
Replace the uses of deprecated Structured Op Interface methods in FusionOnTensors.cpp. This patch is based on https://reviews.llvm.org/D103394.
Differential Revision: https://reviews.llvm.org/D103471
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 6ee4d765d5f8d..9b2292f46c3a8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -28,7 +28,7 @@ using namespace mlir::linalg;
/// Conditions for elementwise fusion of generic operations.
static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
- unsigned consumerIdx) {
+ OpOperand *consumerOpOperand) {
// Producer and consumer must have tensor semantics.
if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
return false;
@@ -40,12 +40,12 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
// Only allow fusing the producer of an input operand for now.
// TODO: allow fusing the producer of an output operand.
- if (consumerIdx >= consumer.getNumInputs())
+ if (!consumer.isInputTensor(consumerOpOperand))
return false;
// Get the consumer index map. The number of results of the consumer index
// map must match the number of loops of the producer.
- AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
+ AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
if (consumerIndexMap.getNumResults() != producer.getNumLoops())
return false;
@@ -55,7 +55,8 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
// Finally the index_map for the result must be invertible. For now just
// verify it is a permutation.
- AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
+ AffineMap producerResultIndexMap =
+ producer.getTiedIndexingMap(producer.getOutputOperand(0));
return producerResultIndexMap.isPermutation();
}
@@ -63,7 +64,7 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
/// the `producer` to use in the fused operation given the indexing map of the
/// result of the producer in the consumer.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
- OpOperand &producerOpOperand, AffineMap producerResultIndexMap,
+ OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
AffineMap fusedConsumerArgIndexMap) {
// The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
// from consumer loop -> consumer arg tensor index/producer result tensor
@@ -78,10 +79,9 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
assert(invProducerResultIndexMap &&
"expected producer result indexig map to be invertible");
- LinalgOp producer = cast<LinalgOp>(producerOpOperand.getOwner());
+ LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
// argMap is a map from producer loop -> producer arg tensor index.
- AffineMap argMap =
- producer.getIndexingMap(producerOpOperand.getOperandNumber());
+ AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
// Compose argMap with invProducerResultIndexMap to get a map from
// producer result tensor index -> producer arg tensor index.
@@ -96,9 +96,11 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
/// op must be empty.
static void
generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
- GenericOp producer, GenericOp consumer,
AffineMap consumerToProducerLoopsMap,
- unsigned consumerIdx, unsigned nloops) {
+ OpOperand *consumerOpOperand,
+ unsigned nloops) {
+ auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
+ auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
// Build the region of the fused op.
Block &producerBlock = producer->getRegion(0).front();
Block &consumerBlock = consumer->getRegion(0).front();
@@ -129,11 +131,11 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
}
}
// TODO: allow fusing the producer of an output operand.
- assert(consumerIdx < consumer.getNumInputs() &&
+ assert(consumer.isInputTensor(consumerOpOperand) &&
"expected producer of input operand");
// 3. Consumer input operands up to consumerIdx (exclusive).
for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
- consumerIdx)) // input assumption.
+ consumerOpOperand->getOperandNumber())) // input assumption.
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
// Replacing consumerIdx requires getting the cloned, yielded, value from
@@ -147,7 +149,7 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
// 4.b. Producer output operand/map that is fused needs to be mapped to the
// producer bbArg if it is an "initTensor" (i.e. its value is actually read).
assert(producer->getNumResults() == 1 && "expected single result producer");
- if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
+ if (producer.isInitTensor(producer.getOutputOperand(0))) {
BlockArgument bbArg = producerBlock.getArguments()
.drop_front(producer.getNumInputs())
// TODO: bbArg index of
@@ -155,9 +157,10 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
}
// 5. Remaining consumer's input operands (drop past index `consumerIdx`).
- for (BlockArgument bbArg : consumerBlock.getArguments()
- .take_front(consumer.getNumInputs())
- .drop_front(consumerIdx + 1))
+ for (BlockArgument bbArg :
+ consumerBlock.getArguments()
+ .take_front(consumer.getNumInputs())
+ .drop_front(consumerOpOperand->getOperandNumber() + 1))
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
// 6. All of consumer's output operands.
for (BlockArgument bbArg :
@@ -191,7 +194,8 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
assert(!producer->isAncestor(replacement.getDefiningOp()) &&
"yielded value must have been mapped");
}
- mapper.map(consumerBlock.getArgument(consumerIdx), replacement);
+ mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
+ replacement);
// 10. Clone operations from the consumer to the fused op.
for (auto &op : consumerBlock.getOperations())
rewriter.clone(op, mapper);
@@ -202,17 +206,16 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
}
static Optional<SmallVector<Value>>
-fuseElementwiseOpsImpl(GenericOp producer, OpOperand &consumerOpOperand,
+fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
const ControlElementwiseOpsFusionFn &controlFn,
PatternRewriter &rewriter) {
- auto consumer = cast<GenericOp>(consumerOpOperand.getOwner());
- unsigned consumerIdx = consumerOpOperand.getOperandNumber();
- if (!areElementwiseOpsFusable(producer, consumer, consumerIdx) ||
- !controlFn(producer->getResult(0), consumerOpOperand))
+ auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
+ if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
+ !controlFn(producer->getResult(0), *consumerOpOperand))
return llvm::None;
// TODO: allow fusing the producer of an output operand.
- assert(consumerIdx < consumer.getNumInputs() &&
+ assert(consumer.isInputTensor(consumerOpOperand) &&
"expected producer of input operand");
// Compute the fused operands list and indexing maps.
@@ -224,62 +227,66 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand &consumerOpOperand,
consumer->getNumOperands());
// In the following, numbering matches that of `generateFusedTensorOpRegion`.
// 3. Consumer input operands/maps up to consumerIdx (exclusive).
- llvm::append_range(fusedOperands,
- consumer.getInputs().take_front(consumerIdx));
- llvm::append_range(
- fusedIndexMaps,
- ArrayRef<AffineMap>{consumer.getInputIndexingMaps()}.take_front(
- consumerIdx));
+ SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
+ SmallVector<OpOperand *>::iterator it =
+ llvm::find(consumerInputs, consumerOpOperand);
+ assert(it != consumerInputs.end() && "expected to find the consumer operand");
+ for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
+ fusedOperands.push_back(opOperand->get());
+ fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
+ }
// 4. Splice in producer's input operands/maps.
- llvm::append_range(fusedOperands, producer.getInputs());
assert(producer->getNumResults() == 1 && "expected single result producer");
- AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
- for (auto &inputOpOperand : producer.getInputOpOperands()) {
+ AffineMap producerResultIndexMap =
+ producer.getTiedIndexingMap(producer.getOutputOperand(0));
+ for (OpOperand *opOperand : producer.getInputOperands()) {
+ fusedOperands.push_back(opOperand->get());
// Compute indexing maps for the producer args in the fused operation.
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
- inputOpOperand, producerResultIndexMap,
- consumer.getInputIndexingMap(consumerIdx));
+ opOperand, producerResultIndexMap,
+ consumer.getTiedIndexingMap(consumerOpOperand));
fusedIndexMaps.push_back(map);
}
// 4.b. Producer output operand/map that is fused needs to be passed if it is
// an "initTensor" (i.e. its value is actually read).
assert(producer->getNumResults() == 1 && "expected single result producer");
- if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
- llvm::append_range(fusedOperands, producer.getOutputs().take_front());
+ if (producer.isInitTensor(producer.getOutputOperand(0))) {
+ fusedOperands.push_back(producer.getOutputOperand(0)->get());
// Compute indexing maps for the producer args in the fused operation.
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
- producer.getOutputOpOperands().front(), producerResultIndexMap,
- consumer.getOutputIndexingMap(0));
+ producer.getOutputOperand(0), producerResultIndexMap,
+ consumer.getTiedIndexingMap(consumerOpOperand));
fusedIndexMaps.push_back(map);
}
// 5. Remaining consumer's input operands/maps (drop past index
// `consumerIdx`).
- llvm::append_range(fusedOperands,
- consumer.getInputs().drop_front(consumerIdx + 1));
- llvm::append_range(
- fusedIndexMaps,
- ArrayRef<AffineMap>{consumer.getInputIndexingMaps()}.drop_front(
- consumerIdx + 1));
+ for (OpOperand *opOperand :
+ llvm::make_range(std::next(it), consumerInputs.end())) {
+ fusedOperands.push_back(opOperand->get());
+ fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
+ }
// 6. All of consumer's output operands (skip operands: added by the builder).
- // llvm::append_range(fusedOperands, consumer.getOutputs());
- llvm::append_range(fusedIndexMaps, consumer.getOutputIndexingMaps());
+ for (OpOperand *opOperand : consumer.getOutputOperands())
+ fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
// 7. All of producer's output operands/maps except the one fused.
// TODO: allow fusion of multi-result producers.
assert(producer->getNumResults() == 1 && "expected single result producer");
// Generate the fused op.
+ SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
auto fusedOp = rewriter.create<GenericOp>(
consumer.getLoc(), consumer->getResultTypes(),
/*inputs=*/fusedOperands,
// TODO: handle outputs.
- consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+ consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
consumer.iterator_types(),
/*doc=*/nullptr,
/*library_call=*/nullptr);
// Construct an AffineMap from consumer loops to producer loops.
// consumer loop -> tensor index
- AffineMap consumerResultIndexMap = consumer.getInputIndexingMap(consumerIdx);
+ AffineMap consumerResultIndexMap =
+ consumer.getTiedIndexingMap(consumerOpOperand);
// tensor index -> producer loop
AffineMap invProducerResultIndexMap =
inversePermutation(producerResultIndexMap);
@@ -289,9 +296,9 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand &consumerOpOperand,
AffineMap consumerToProducerLoopsMap =
invProducerResultIndexMap.compose(consumerResultIndexMap);
- generateFusedElementwiseOpRegion(rewriter, fusedOp, producer, consumer,
- consumerToProducerLoopsMap, consumerIdx,
- consumer.getNumLoops());
+ generateFusedElementwiseOpRegion(rewriter, fusedOp,
+ consumerToProducerLoopsMap,
+ consumerOpOperand, consumer.getNumLoops());
return SmallVector<Value>(fusedOp->getResults());
}
@@ -449,7 +456,7 @@ static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
/// The added reshapes are again expanding patterns, so they will get fused
/// with its producers if possible.
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
- unsigned fusedTensorIndex) {
+ OpOperand *fusableOpOperand) {
// Is fusable only if:
// - All the indexing maps for operands and results are projected
// permutations.
@@ -462,7 +469,7 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
.getValue()
.isProjectedPermutation();
}) &&
- genericOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 &&
+ genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
return attr.cast<StringAttr>().getValue() ==
getParallelIteratorTypeName();
@@ -478,7 +485,7 @@ class ExpansionInfo {
// of the expanded op given the `indexingMap` of the fused operand/result of
// the generic op, the `reassocationMaps` of the reshape op and the shape of
// the expanded op.
- LogicalResult compute(LinalgOp linalgOp, unsigned fusedTensorIndex,
+ LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
ArrayRef<int64_t> expandedShape,
PatternRewriter &rewriter);
@@ -503,13 +510,13 @@ class ExpansionInfo {
} // namespace
LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
- unsigned fusedTensorIndex,
+ OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
ArrayRef<int64_t> expandedShape,
PatternRewriter &rewriter) {
if (reassociationMaps.empty())
return failure();
- AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
+ AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
Optional<SmallVector<int64_t, 4>> originalLoopRange =
linalgOp.getStaticLoopRanges();
@@ -676,9 +683,9 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
/// been satisfied.
static Optional<SmallVector<Value>>
fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
- unsigned fusedTensorIndex,
+ OpOperand *fusableOpOperand,
PatternRewriter &rewriter) {
- assert(isFusableWithReshapeByDimExpansion(genericOp, fusedTensorIndex) &&
+ assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
"preconditions for fuse operation failed");
// Check if reshape is expanding or collapsing.
bool isExpanding =
@@ -687,7 +694,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
ExpansionInfo expansionInfo;
- if (failed(expansionInfo.compute(genericOp, fusedTensorIndex,
+ if (failed(expansionInfo.compute(genericOp, fusableOpOperand,
reshapeOp.getReassociationMaps(),
expandedType.getShape(), rewriter)))
return llvm::None;
@@ -701,39 +708,39 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
}));
SmallVector<Value> expandedOpOperands;
- for (auto operand : llvm::enumerate(genericOp.getInputs())) {
- if (operand.index() == fusedTensorIndex) {
+ for (OpOperand *opOperand : genericOp.getInputOperands()) {
+ if (opOperand == fusableOpOperand) {
expandedOpOperands.push_back(reshapeOp.src());
continue;
}
- AffineMap indexingMap = genericOp.getInputIndexingMap(operand.index());
+ AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
RankedTensorType expandedOperandType =
- getExpandedType(operand.value().getType().cast<RankedTensorType>(),
+ getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
indexingMap, expansionInfo);
- if (expandedOperandType != operand.value().getType()) {
+ if (expandedOperandType != opOperand->get().getType()) {
// Reshape the operand to get the right type.
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
expandedOpOperands.push_back(rewriter.create<TensorReshapeOp>(
- genericOp.getLoc(), expandedOperandType, operand.value(),
+ genericOp.getLoc(), expandedOperandType, opOperand->get(),
reassociation));
continue;
}
- expandedOpOperands.push_back(operand.value());
+ expandedOpOperands.push_back(opOperand->get());
}
Location loc = genericOp.getLoc();
SmallVector<Value> outputs;
- for (auto result : llvm::enumerate(genericOp.getOutputs())) {
- AffineMap indexingMap = genericOp.getOutputIndexingMap(result.index());
+ for (OpOperand *opOperand : genericOp.getOutputOperands()) {
+ AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
RankedTensorType expandedOutputType =
- getExpandedType(result.value().getType().cast<RankedTensorType>(),
+ getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
indexingMap, expansionInfo);
- if (expandedOutputType != result.value().getType()) {
+ if (expandedOutputType != opOperand->get().getType()) {
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
outputs.push_back(rewriter.create<TensorReshapeOp>(
- genericOp.getLoc(), expandedOutputType, result.value(),
+ genericOp.getLoc(), expandedOutputType, opOperand->get(),
reassociation));
}
}
@@ -757,17 +764,19 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
// Reshape the result values to their original shape if this is a collapsing
// reshape folded into its consumer.
SmallVector<Value> resultVals;
- for (auto result : llvm::enumerate(genericOp->getResults())) {
- if (!isExpanding &&
- resultTypes[result.index()] != result.value().getType()) {
+ for (OpResult opResult : genericOp->getOpResults()) {
+ int64_t resultNumber = opResult.getResultNumber();
+ if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(
- genericOp.getOutputIndexingMap(result.index()), expansionInfo);
+ genericOp.getTiedIndexingMap(
+ genericOp.getOutputOperand(resultNumber)),
+ expansionInfo);
resultVals.push_back(rewriter.create<TensorReshapeOp>(
- genericOp.getLoc(), result.value().getType(),
- fusedOp->getResult(result.index()), reassociation));
+ genericOp.getLoc(), opResult.getType(),
+ fusedOp->getResult(resultNumber), reassociation));
} else {
- resultVals.push_back(fusedOp->getResult(result.index()));
+ resultVals.push_back(fusedOp->getResult(resultNumber));
}
}
// Assuming a single result.
@@ -809,12 +818,13 @@ struct FoldProducerReshapeOpByLinearization
PatternRewriter &rewriter) const override {
if (!genericOp.hasTensorSemantics())
return failure();
- for (auto operand : llvm::enumerate(genericOp.getInputs())) {
+ SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
+ for (auto en : llvm::enumerate(inputOperands)) {
TensorReshapeOp reshapeOp =
- operand.value().getDefiningOp<TensorReshapeOp>();
+ en.value()->get().getDefiningOp<TensorReshapeOp>();
if (!reshapeOp ||
!isTensorReshapeOpFoldableByLinearization(
- reshapeOp, genericOp.getInputIndexingMap(operand.index()),
+ reshapeOp, genericOp.getTiedIndexingMap(en.value()),
/*asProducer =*/true) ||
(foldUnitDimReshapesOnly &&
!isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
@@ -822,18 +832,17 @@ struct FoldProducerReshapeOpByLinearization
continue;
// Compute the fused operands list,
- SmallVector<Value> fusedOperands(genericOp.getInputs());
- fusedOperands[operand.index()] = reshapeOp.src();
- fusedOperands.append(genericOp.getOutputs().begin(),
- genericOp.getOutputs().end());
+ SmallVector<Value> fusedOperands = genericOp.getInputOperands();
+ fusedOperands[en.index()] = reshapeOp.src();
+ SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+ llvm::append_range(fusedOperands, outputOperands);
// Compute indexing_maps for the fused operation. The indexing_maps for
// the operands of the consumers that aren't fused are the same.
- SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
- genericOp.indexing_maps().template getAsValueRange<AffineMapAttr>());
+ SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
// Accepted consumer maps are either identity or permutation.
- auto invMap = inversePermutation(fusedIndexMaps[operand.index()]);
+ auto invMap = inversePermutation(fusedIndexMaps[en.index()]);
// Compute the indexing map to use for the result of the producer.
AffineMap modifiedMap =
@@ -843,7 +852,7 @@ struct FoldProducerReshapeOpByLinearization
if (!expr.isPureAffine())
return failure();
}
- fusedIndexMaps[operand.index()] = modifiedMap;
+ fusedIndexMaps[en.index()] = modifiedMap;
// Further check that the resulting index maps can be fused and
// inverted. Without this the resultant op is not legal.
@@ -917,35 +926,36 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
return failure();
// Only support identity output maps. It could be extended to permutations if
// needed.
- if (llvm::any_of(genericOp.getOutputIndexingMaps(),
- [](AffineMap map) { return !map.isIdentity(); }))
+ if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
+ return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
+ }))
return failure();
int64_t destRank = genericOp.getNumParallelLoops();
- SmallVector<Value, 4> newOperands =
- llvm::to_vector<4>(genericOp.getInputs());
+ SmallVector<Value> newOperands = genericOp.getInputOperands();
TensorReshapeOp reshapeFound;
// 1. Look for tensor_reshape operands and figure out save the dimensions
// merged.
- for (auto operand : llvm::enumerate(genericOp.getInputs())) {
+ SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
+ for (auto en : llvm::enumerate(inputOperands)) {
TensorReshapeOp reshapeOp =
- operand.value().template getDefiningOp<TensorReshapeOp>();
+ en.value()->get().template getDefiningOp<TensorReshapeOp>();
if (!reshapeOp || reshapeOp.getSrcType().getRank() >
reshapeOp.getResultType().getRank()) {
continue;
}
// TODO: We could support non-identity map as long as the merged
// dimensions are still contiguous.
- if (!genericOp.getIndexingMaps()[operand.index()].isIdentity())
+ if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
continue;
if (reshapeFound) {
// Only support a second reshape op if it has the same reassociate maps.
if (reshapeFound.getReassociationMaps() ==
reshapeOp.getReassociationMaps())
- newOperands[operand.index()] = reshapeOp.src();
+ newOperands[en.index()] = reshapeOp.src();
continue;
}
reshapeFound = reshapeOp;
- newOperands[operand.index()] = reshapeOp.src();
+ newOperands[en.index()] = reshapeOp.src();
}
if (!reshapeFound)
return failure();
@@ -962,9 +972,9 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
// 2. Verify that we can merge the dimensions in the linalg and that we
// don't need to create new reshapes operands. Inserting new reshape
// operands would defeat the purpose of the transformation.
- for (auto operand : llvm::enumerate(genericOp.getInputs())) {
- if (operand.value() == newOperands[operand.index()]) {
- AffineMap map = genericOp.getIndexingMaps()[operand.index()];
+ for (auto en : llvm::enumerate(inputOperands)) {
+ if (en.value()->get() == newOperands[en.index()]) {
+ AffineMap map = genericOp.getTiedIndexingMap(en.value());
for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
return failure();
@@ -1036,9 +1046,9 @@ class FoldWithProducerReshapeOpByExpansion
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- for (auto operand : llvm::enumerate(genericOp.getInputs())) {
+ for (OpOperand *opOperand : genericOp.getInputOperands()) {
TensorReshapeOp reshapeOp =
- operand.value().getDefiningOp<TensorReshapeOp>();
+ opOperand->get().getDefiningOp<TensorReshapeOp>();
if (!reshapeOp)
continue;
// Fold only if
@@ -1046,15 +1056,12 @@ class FoldWithProducerReshapeOpByExpansion
// - All constraints of fusing with reshape by expansion are met.
if (reshapeOp.getSrcType().getRank() <
reshapeOp.getResultType().getRank() ||
- !isFusableWithReshapeByDimExpansion(genericOp, operand.index()) ||
- (!controlFoldingReshapes(
- reshapeOp->getResult(0),
- genericOp.getInputOpOperands()[operand.index()])))
+ !isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
+ (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
continue;
Optional<SmallVector<Value>> replacementValues =
- fuseWithReshapeByExpansion(genericOp, reshapeOp, operand.index(),
- rewriter);
+ fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
if (!replacementValues)
return failure();
rewriter.replaceOp(genericOp, replacementValues.getValue());
@@ -1080,7 +1087,8 @@ struct FoldConsumerReshapeOpByLinearization
if (!producer || !producer.hasTensorSemantics() ||
producer.getNumOutputs() != 1 ||
!isTensorReshapeOpFoldableByLinearization(
- reshapeOp, producer.getOutputIndexingMap(0),
+ reshapeOp,
+ producer.getTiedIndexingMap(producer.getOutputOperand(0)),
/*asProducer =*/false) ||
(foldUnitDimReshapesOnly &&
!isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
@@ -1088,10 +1096,10 @@ struct FoldConsumerReshapeOpByLinearization
return failure();
// The indexing_maps for the operands of the fused operation are same as
// those for the operands of the producer.
- SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
- producer.indexing_maps().getAsValueRange<AffineMapAttr>());
+ SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
- auto invMap = inversePermutation(producer.getOutputIndexingMap(0));
+ auto invMap = inversePermutation(
+ producer.getTiedIndexingMap(producer.getOutputOperand(0)));
// Compute the indexing map to use for the operand of the producer.
AffineMap modifiedMap =
@@ -1113,11 +1121,13 @@ struct FoldConsumerReshapeOpByLinearization
}
Location loc = producer.getLoc();
+ SmallVector<Value> inputOperands = producer.getInputOperands();
Value output = rewriter.create<TensorReshapeOp>(
- loc, producer.getOutputs()[0], reshapeOp.getReassociationExprs());
+ loc, producer.getOutputOperand(0)->get(),
+ reshapeOp.getReassociationExprs());
auto fusedOp = rewriter.create<GenericOp>(
loc, reshapeOp.getResultType(),
- /*inputs=*/producer.getInputs(),
+ /*inputs=*/inputOperands,
// TODO: handle outputs.
/*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
producer.iterator_types(),
@@ -1147,12 +1157,12 @@ struct FoldReshapeWithGenericOpByExpansion
GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
if (!producer || producer.getNumOutputs() != 1 ||
!isFusableWithReshapeByDimExpansion(producer,
- producer.getNumInputs()) ||
+ producer.getOutputOperand(0)) ||
isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
reshapeOp.getReassociationMaps()))
return failure();
Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
- producer, reshapeOp, producer.getNumInputs(), rewriter);
+ producer, reshapeOp, producer.getOutputOperand(0), rewriter);
if (!replacementValues)
return failure();
rewriter.replaceOp(reshapeOp, replacementValues.getValue());
@@ -1171,21 +1181,29 @@ class FoldSplatConstants : public OpRewritePattern<GenericOp> {
PatternRewriter &rewriter) const override {
if (!genericOp.hasTensorSemantics())
return failure();
- for (auto operand : llvm::enumerate(genericOp.getInputOpOperands())) {
- Operation *def = operand.value().get().getDefiningOp();
+ for (OpOperand *opOperand : genericOp.getInputOperands()) {
+ Operation *def = opOperand->get().getDefiningOp();
DenseElementsAttr constantAttr;
if (!def ||
!matchPattern(def, m_Constant<DenseElementsAttr>(&constantAttr)) ||
- !constantAttr.isSplat() ||
- !controlFn(def->getResult(0), operand.value()))
+ !constantAttr.isSplat() || !controlFn(def->getResult(0), *opOperand))
continue;
- // The indexing_maps for the operands of the fused operation are same as
- // those for the operands of the genericOp without the indexing map at
- // operand.index()
- SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
- genericOp.indexing_maps().getAsValueRange<AffineMapAttr>());
- fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index()));
+ // The operands and the indexing_maps of the fused operation are the same as
+ // the operands and indexing_maps of the generic operations with the
+ // values at the constant index dropped.
+ SmallVector<AffineMap> fusedIndexMaps;
+ SmallVector<Value> fusedOperands;
+ fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
+ fusedOperands.reserve(genericOp.getNumInputs());
+ for (OpOperand *inputOperand : genericOp.getInputOperands()) {
+ if (inputOperand == opOperand)
+ continue;
+ fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
+ fusedOperands.push_back(inputOperand->get());
+ }
+ for (OpOperand *outputOperand : genericOp.getOutputOperands())
+ fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));
// Check if the operation shapes to loops map is computable.
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
@@ -1193,20 +1211,16 @@ class FoldSplatConstants : public OpRewritePattern<GenericOp> {
genericOp, "fused op loop bound computation failed");
}
- // The operands list is same as the genericOp with the argument for
- // constant index dropped.
- SmallVector<Value> fusedOperands(genericOp.getInputs());
- fusedOperands.erase(std::next(fusedOperands.begin(), operand.index()));
-
// Create a constant scalar value from the splat constant.
Value scalarConstant = rewriter.create<ConstantOp>(
def->getLoc(), constantAttr.getSplatValue(),
constantAttr.getType().getElementType());
+ SmallVector<Value> outputOperands = genericOp.getOutputOperands();
auto fusedOp = rewriter.create<GenericOp>(
rewriter.getUnknownLoc(), genericOp->getResultTypes(),
/*inputs=*/fusedOperands,
- /*outputs=*/genericOp.getOutputs(),
+ /*outputs=*/outputOperands,
rewriter.getAffineMapArrayAttr(fusedIndexMaps),
genericOp.iterator_types(),
/*doc=*/nullptr,
@@ -1217,7 +1231,8 @@ class FoldSplatConstants : public OpRewritePattern<GenericOp> {
Region &region = genericOp->getRegion(0);
Block &entryBlock = *region.begin();
BlockAndValueMapping mapping;
- mapping.map(entryBlock.getArgument(operand.index()), scalarConstant);
+ mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
+ scalarConstant);
Region &fusedRegion = fusedOp->getRegion(0);
rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
mapping);
@@ -1233,7 +1248,7 @@ class FoldSplatConstants : public OpRewritePattern<GenericOp> {
} // namespace
static Optional<SmallVector<Value>>
-fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand,
+fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
GenericOp producer,
const ControlElementwiseOpsFusionFn &controlFn) {
if (producer->getNumResults() != 1)
@@ -1261,9 +1276,9 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
- for (OpOperand &opOperand : genericOp.getShapedOpOperands()) {
+ for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
auto producer =
- dyn_cast_or_null<GenericOp>(opOperand.get().getDefiningOp());
+ dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
if (!producer || !producer.hasTensorSemantics())
continue;
Optional<SmallVector<Value>> fusedOpResults =
@@ -1322,9 +1337,9 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
rewriter.startRootUpdate(op);
bool modifiedOutput = false;
Location loc = op.getLoc();
- for (OpOperand &opOperand : op.getOutputOpOperands()) {
- if (!op.payloadUsesValueFromOpOperand(&opOperand)) {
- Value operandVal = opOperand.get();
+ for (OpOperand *opOperand : op.getOutputOperands()) {
+ if (!op.payloadUsesValueFromOperand(opOperand)) {
+ Value operandVal = opOperand->get();
auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
if (!operandType)
continue;
@@ -1344,7 +1359,7 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
Value initTensor = rewriter.create<InitTensorOp>(
loc, dynamicDims, operandType.getShape(),
operandType.getElementType());
- op->setOperand(opOperand.getOperandNumber(), initTensor);
+ op->setOperand(opOperand->getOperandNumber(), initTensor);
}
}
if (!modifiedOutput) {
More information about the Mlir-commits
mailing list