[Mlir-commits] [mlir] bcd6424 - [mlir][Linalg] Fix linalg on tensor fusion
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Mar 22 06:30:02 PDT 2021
Author: Nicolas Vasilache
Date: 2021-03-22T13:29:40Z
New Revision: bcd6424f9b693af57b29a0f03c52d6991be35d41
URL: https://github.com/llvm/llvm-project/commit/bcd6424f9b693af57b29a0f03c52d6991be35d41
DIFF: https://github.com/llvm/llvm-project/commit/bcd6424f9b693af57b29a0f03c52d6991be35d41.diff
LOG: [mlir][Linalg] Fix linalg on tensor fusion
- Drop unnecessary occurrences of rewriter.eraseOp: dead linalg ops on tensors should be cleaned up by DCE.
- Reimplement the part of Linalg on tensors fusion that constructs the body and block arguments: the previous implementation had too much magic. Instead, this spells out all cases explicitly and asserts / introduces TODOs for incorrect cases.
As a consequence, we can use the default traversal order for this pattern.
Differential Revision: https://reviews.llvm.org/D99070
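
To illustrate the kind of elementwise fusion this pattern performs, here is a minimal sketch (not part of the commit), written in the same style as the test added below; the exact fused form and operand order are assumptions based on the block-argument ordering spelled out in generateFusedTensorOpRegion:

  // Hypothetical input: an elementwise producer feeding an elementwise consumer.
  #map = affine_map<(d0) -> (d0)>
  func @fuse_add_mul(%a: tensor<8xf32>, %b: tensor<8xf32>, %c: tensor<8xf32>,
                     %init: tensor<8xf32>) -> tensor<8xf32> {
    // Producer: %t = %a + %b.
    %t = linalg.generic {indexing_maps = [#map, #map, #map],
                         iterator_types = ["parallel"]}
        ins(%a, %b : tensor<8xf32>, tensor<8xf32>)
        outs(%init : tensor<8xf32>) {
      ^bb0(%x: f32, %y: f32, %o: f32):
        %s = addf %x, %y : f32
        linalg.yield %s : f32
    } -> tensor<8xf32>
    // Consumer: %r = %t * %c. After fusion, both bodies end up in a single
    // linalg.generic; its block arguments follow the order constructed in
    // generateFusedTensorOpRegion (consumer inputs up to the fused operand,
    // then producer inputs, then the remaining consumer inputs and outputs).
    %r = linalg.generic {indexing_maps = [#map, #map, #map],
                         iterator_types = ["parallel"]}
        ins(%t, %c : tensor<8xf32>, tensor<8xf32>)
        outs(%init : tensor<8xf32>) {
      ^bb0(%x: f32, %y: f32, %o: f32):
        %p = mulf %x, %y : f32
        linalg.yield %p : f32
    } -> tensor<8xf32>
    return %r : tensor<8xf32>
  }

With the erase calls dropped, the now-dead producer linalg.generic is left for DCE rather than being removed inside the pattern.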
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/test/Dialect/Linalg/fusion-tensor.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 1e94dfd3ef94..a6d0fd5dd7b7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -37,6 +37,11 @@ static bool areTensorOpsFusable(LinalgOp producer, LinalgOp consumer,
if (producer.getNumParallelLoops() != producer.getNumLoops())
return false;
+ // Only allow fusing the producer of an input operand for now.
+ // TODO: allow fusing the producer of an output operand.
+ if (consumerIdx >= consumer.getNumInputs())
+ return false;
+
// Get the consumer index map. The number of results of the consumer index
// map must match the number of loops of the producer.
AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
@@ -120,60 +125,86 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
isa<IndexedGenericOp>(consumer.getOperation()))
? std::max(producer.getNumLoops(), consumer.getNumLoops())
: 0;
- // Firstly, add all the indices to the block arguments.
+
+ // 0. Firstly, add all the indices to the block arguments.
for (unsigned i = 0, e = numFusedOpIndices; i < e; ++i)
fusedBlock->addArgument(rewriter.getIndexType());
- // Map the arguments for the unmodified args from the consumer.
- for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
- if (consumerArg.index() == consumerIdx + numConsumerIndices) {
- // Map the arguments for the args from the producer.
- for (auto producerArg :
- llvm::enumerate(producerBlock.getArguments().take_front(
- producer.getNumInputs() + numProducerIndices))) {
- // If producer is an indexed_generic op, map the indices from consumer
- // loop to producer loop (because the fusedOp is built based on
- // consumer's perspective).
- if (producerArg.index() < numProducerIndices) {
- auto newIndex = rewriter.create<mlir::AffineApplyOp>(
- producer.getLoc(),
- consumerToProducerLoopsMap.getSubMap(producerArg.index()),
- fusedBlock->getArguments().take_front(numFusedOpIndices));
- mapper.map(producerArg.value(), newIndex);
- } else {
- mapper.map(producerArg.value(),
- fusedBlock->addArgument(producerArg.value().getType()));
- }
- }
- continue;
- }
-
- // If consumer is an indexed_generic op, map the indices to the block
- // arguments directly. Otherwise, add the same type of argument and map to
- // it.
- if (consumerArg.index() < numConsumerIndices) {
- mapper.map(consumerArg.value(),
- fusedBlock->getArgument(consumerArg.index()));
- } else {
- mapper.map(consumerArg.value(),
- fusedBlock->addArgument(consumerArg.value().getType()));
- }
+ // 1. Map consumer indices to fusedBlock indices 1-1.
+ mapper.map(consumerBlock.getArguments().take_front(numConsumerIndices),
+ fusedBlock->getArguments().take_front(numConsumerIndices));
+ // 2. Embed producer indices into fusedBlock index space 1-1.
+ for (auto it :
+ llvm::zip(producerBlock.getArguments().take_front(numProducerIndices),
+ fusedBlock->getArguments().take_front(numProducerIndices))) {
+ auto newIndex = rewriter.create<mlir::AffineApplyOp>(
+ producer.getLoc(),
+ consumerToProducerLoopsMap.getSubMap(std::get<0>(it).getArgNumber()),
+ fusedBlock->getArguments().take_front(numFusedOpIndices));
+ mapper.map(std::get<0>(it), newIndex);
}
-
- // Add operations from producer (except the yield operation) to the fused
+ // TODO: allow fusing the producer of an output operand.
+ assert(consumerIdx < consumer.getNumInputs() &&
+ "expected producer of input operand");
+ // 3. Consumer input operands up to consumerIdx (exclusive).
+ for (BlockArgument bbArg : consumerBlock.getArguments()
+ .drop_front(numConsumerIndices)
+ .take_front(consumerIdx)) // input assumption.
+ mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
+
+ // Replacing consumerIdx requires getting the cloned, yielded, value from
+ // the (cloned) producer block. This happens in step 9.
+
+ // 4. Splice in producer's input operands.
+ for (BlockArgument bbArg : producerBlock.getArguments()
+ .drop_front(numProducerIndices)
+ .take_front(producer.getNumInputs()))
+ mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
+ // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
+ for (BlockArgument bbArg : consumerBlock.getArguments()
+ .drop_front(numConsumerIndices)
+ .take_front(consumer.getNumInputs())
+ .drop_front(consumerIdx + 1))
+ mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
+ // 6. All of consumer's output operands.
+ for (BlockArgument bbArg :
+ consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
+ mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
+ // 7. All of producer's output operands except the one fused.
+ // TODO: allow fusion of multi-result producers.
+ assert(producer->getNumResults() == 1 && "expected single result producer");
+
+ // 8. Clone operations from producer (except the yield operation) to the fused
// op.
- for (auto &op : producerBlock.getOperations()) {
- if (auto yieldOp = dyn_cast<linalg::YieldOp>(op)) {
- // Lookup the value the yield operation is mapped to.
- Value yieldVal = yieldOp.getOperand(0);
- if (Value clonedVal = mapper.lookupOrNull(yieldVal))
- mapper.map(consumerBlock.getArgument(consumerIdx + numConsumerIndices),
- clonedVal);
- continue;
- }
+ for (auto &op : producerBlock.without_terminator())
rewriter.clone(op, mapper);
+ // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
+ // forward the yield operand.
+ auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
+ // TODO: allow fusion of multi-result producers.
+ assert(producer->getNumResults() == 1 && "expected single result producer");
+ unsigned producerResultNumber = 0;
+ Value replacement =
+ mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
+ // Sanity checks, if replacement is not already in the mapper then it must be
+ // produced outside.
+ if (replacement == yieldOp.getOperand(producerResultNumber)) {
+ if (auto bb = replacement.dyn_cast<BlockArgument>())
+ assert(bb.getOwner() != &producerBlock &&
+ "yielded block argument must have been mapped");
+ else
+ assert(!producer->isAncestor(replacement.getDefiningOp()) &&
+ "yielded value must have been mapped");
}
+ mapper.map(consumerBlock.getArgument(consumerIdx + numConsumerIndices),
+ replacement);
+ // 10. Clone operations from the consumer to the fused op.
for (auto &op : consumerBlock.getOperations())
rewriter.clone(op, mapper);
+
+ // Sanity checks.
+ assert(fusedBlock->getNumArguments() ==
+ fusedOp->getNumOperands() + numFusedOpIndices &&
+ "Ill-formed LinalgOp region");
}
static Optional<SmallVector<Value, 1>>
@@ -856,8 +887,6 @@ struct FoldProducerReshapeOpByLinearization
op->setOperands(fusedOperands);
op.indexing_mapsAttr(rewriter.getAffineMapArrayAttr(fusedIndexMaps));
rewriter.finalizeRootUpdate(op);
- if (reshapeOp.use_empty())
- rewriter.eraseOp(reshapeOp);
return success();
}
return failure();
@@ -897,8 +926,6 @@ struct FoldWithProducerReshapeOpByExpansion
if (!replacementValues)
return failure();
rewriter.replaceOp(genericOp, replacementValues.getValue());
- if (reshapeOp.use_empty())
- rewriter.eraseOp(reshapeOp);
return success();
}
return failure();
@@ -963,8 +990,6 @@ struct FoldConsumerReshapeOpByLinearization
rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
fusedRegion.begin());
rewriter.replaceOp(reshapeOp, fusedOp->getResults());
- if (producer.use_empty())
- rewriter.eraseOp(producer);
return success();
}
};
@@ -995,8 +1020,6 @@ struct FoldReshapeWithGenericOpByExpansion
if (!replacementValues)
return failure();
rewriter.replaceOp(reshapeOp, replacementValues.getValue());
- if (producer.use_empty())
- rewriter.eraseOp(producer);
return success();
}
};
@@ -1057,8 +1080,6 @@ struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
rewriter.cloneRegionBefore(linalgOpRegion, fusedRegion,
fusedRegion.begin(), mapping);
rewriter.replaceOp(linalgOp, fusedOp->getResults());
- if (constantOp.use_empty())
- rewriter.eraseOp(constantOp);
return success();
}
return failure();
@@ -1092,15 +1113,14 @@ struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
for (OpOperand &opOperand : op.getShapedOpOperands()) {
- Operation *producer = opOperand.get().getDefiningOp();
- if (!producer)
+ LinalgOp producerOp =
+ dyn_cast_or_null<LinalgOp>(opOperand.get().getDefiningOp());
+ if (!producerOp || !producerOp.hasTensorSemantics())
continue;
Optional<SmallVector<Value, 1>> fusedOpResults =
fuseTensorOps(rewriter, opOperand);
if (fusedOpResults) {
rewriter.replaceOp(op, *fusedOpResults);
- if (producer->use_empty())
- rewriter.eraseOp(producer);
return success();
}
}
@@ -1115,8 +1135,7 @@ struct FusionOfTensorOpsPass
Operation *op = getOperation();
OwningRewritePatternList patterns(op->getContext());
populateLinalgTensorOpsFusionPatterns(patterns);
- (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
- /*useTopDown=*/false);
+ (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index a4071897b4d8..13109bd98c19 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -578,3 +578,41 @@ func @consumer_with_reduction(%arg0: tensor<1x10xf32>,
// CHECK: %[[T4:.+]] = addf %[[T3]], %[[T2]] : f32
// CHECK: linalg.yield %[[T4]]
// CHECK: return %[[RES]]
+
+// -----
+
+// CHECK-LABEL: func @sigmoid_dynamic_dim(
+// CHECK: %[[RES:.*]] = linalg.generic
+// CHECK-NOT: linalg.generic
+// CHECK: return %[[RES]]
+func @sigmoid_dynamic_dim(%0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+ %cp5 = constant 5.000000e-01 : f32
+ %c0 = constant 0 : index
+ %shape = shape.shape_of %0 : tensor<?x1xf32> -> tensor<?xindex>
+ %extend = shape.to_extent_tensor %shape : tensor<?xindex> -> tensor<2xindex>
+ %extracted = tensor.extract %extend[%c0] : tensor<2xindex>
+ %init0 = linalg.init_tensor [%extracted, 1] : tensor<?x1xf32>
+ %1 = linalg.generic {indexing_maps = [
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ }
+ outs(%init0 : tensor<?x1xf32>) {
+ ^bb0(%a: f32): // no predecessors
+ linalg.yield %cp5 : f32
+ } -> tensor<?x1xf32>
+ %d0 = memref.dim %0, %c0 : tensor<?x1xf32>
+ %init1 = linalg.init_tensor [%d0, 1] : tensor<?x1xf32>
+ %2 = linalg.generic {indexing_maps = [
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ }
+ ins(%0, %1 : tensor<?x1xf32>, tensor<?x1xf32>)
+ outs(%init1 : tensor<?x1xf32>) {
+ ^bb0(%a: f32, %b: f32, %c: f32): // no predecessors
+ %m = mulf %a, %b : f32
+ linalg.yield %m : f32
+ } -> tensor<?x1xf32>
+ return %2 : tensor<?x1xf32>
+}