[Mlir-commits] [mlir] 2441c07 - [mlir][bufferization] Support multiple leaves in EmptyTensorElimination
Matthias Springer
llvmlistbot at llvm.org
Fri Feb 10 00:38:55 PST 2023
Author: Matthias Springer
Date: 2023-02-10T09:38:47+01:00
New Revision: 2441c0730603eb6a543dae56c14c3e1ccb08fb55
URL: https://github.com/llvm/llvm-project/commit/2441c0730603eb6a543dae56c14c3e1ccb08fb55
DIFF: https://github.com/llvm/llvm-project/commit/2441c0730603eb6a543dae56c14c3e1ccb08fb55.diff
LOG: [mlir][bufferization] Support multiple leaves in EmptyTensorElimination
Support cases where a source tensor can be traced back to multiple possible tensor.empty ops.
Differential Revision: https://reviews.llvm.org/D142130
Added:
Modified:
mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 1579cfd04c79b..5fc12573912f3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -103,64 +103,73 @@ findValidInsertionPoint(Operation *emptyTensorOp,
/// with the result of `rewriteFunc` if it is anchored on a matching
/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
/// chain, starting from the OpOperand and always following the aliasing
-/// OpOperand, that eventually ends at a single tensor::EmptyOp.
+/// OpOperand, that eventually ends at the tensor::EmptyOp.
+///
+/// E.g.:
+/// %0 = tensor.empty() : tensor<10xf32>
+/// %1 = linalg.fill ... outs(%0 : tensor<10xf32>)
+/// %2 = tensor.insert_slice %0 into %t ...
+///
+/// In the above example, the anchor is the source operand of the insert_slice
+/// op. When tracing back the reverse use-def chain, we end up at a
+/// tensor.empty op.
LogicalResult mlir::bufferization::eliminateEmptyTensors(
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
OpBuilder::InsertionGuard g(rewriter);
- WalkResult status = op->walk([&](Operation *op) {
+ op->walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
// Skip operands that do not bufferize inplace.
if (!state.isInPlace(operand))
continue;
// All values that are needed to create the replacement op.
SmallVector<Value> neededValues;
- // Is this a matching OpOperand?
+ // Is this an anchor?
if (!anchorMatchFunc(operand, neededValues))
continue;
- SetVector<Value> maybeEmptyTensor = state.findValueInReverseUseDefChain(
- operand.get(), /*condition=*/[&](Value val) { return false; },
- /*followEquivalentOnly=*/true);
- // Replace only if the reverse use-def chain ends at exactly one
- // tensor::EmptyOp.
- if (maybeEmptyTensor.size() != 1 ||
- !maybeEmptyTensor.front().getDefiningOp<tensor::EmptyOp>())
- continue;
- Value emptyTensor = maybeEmptyTensor.front();
+ // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
+ // equivalent tensors. I.e., stop when there are ops such as extract_slice
+ // on the path.
+ SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
+ operand.get(), /*condition=*/
+ [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
+ /*followEquivalentOnly=*/true, /*alwaysIncludeLeaves=*/false);
- // Replace only if the types match.
- // TODO: This could be extended to support IR such as:
- // %0 = tensor.empty() : tensor<128xf32>
- // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
- // %2 = tensor.expand_shape %1 ...
- // %3 = tensor.insert_slice %2 into ...
- if (emptyTensor.getType() != operand.get().getType())
- continue;
+ for (Value v : emptyTensors) {
+ Operation *emptyTensorOp = v.getDefiningOp();
- // Find a suitable insertion point.
- Operation *insertionPoint =
- findValidInsertionPoint(emptyTensor.getDefiningOp(), neededValues);
- if (!insertionPoint)
- continue;
+ // Replace only if the types match. We do not support slices or casts.
+ // TODO: This could be extended to support IR such as:
+ // %0 = tensor.empty() : tensor<128xf32>
+ // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
+ // %2 = tensor.expand_shape %1 ...
+ // %3 = tensor.insert_slice %2 into ...
+ if (v.getType() != operand.get().getType())
+ continue;
- // Create a replacement for the tensor::EmptyOp.
- rewriter.setInsertionPoint(insertionPoint);
- Value replacement = rewriteFunc(rewriter, emptyTensor.getLoc(), operand);
- if (!replacement)
- continue;
+ // Find a suitable insertion point. If no suitable insertion point for
+ // the replacement can be found, skip this replacement.
+ Operation *insertionPoint =
+ findValidInsertionPoint(emptyTensorOp, neededValues);
+ if (!insertionPoint)
+ continue;
- // Replace the tensor::EmptyOp.
- rewriter.replaceOp(emptyTensor.getDefiningOp(), replacement);
- state.resetCache();
- }
+ rewriter.setInsertionPoint(insertionPoint);
+ Value replacement =
+ rewriteFunc(rewriter, emptyTensorOp->getLoc(), operand);
+ if (!replacement)
+ continue;
- // Advance to the next operation.
- return WalkResult::advance();
+ // Replace the tensor::EmptyOp.
+ rewriter.replaceOp(emptyTensorOp, replacement);
+ state.resetCache();
+ }
+ }
});
- return failure(status.wasInterrupted());
+ return success();
}
/// Try to eliminate tensor::EmptyOps inside `op`. A tensor::EmptyOp can be
@@ -253,6 +262,7 @@ struct EmptyTensorElimination
void EmptyTensorElimination::runOnOperation() {
Operation *op = getOperation();
OneShotBufferizationOptions options;
+ options.allowReturnAllocs = true;
OneShotAnalysisState state(op, options);
if (failed(analyzeOp(op, state))) {
signalPassFailure();
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 1c0860ffafbef..753840572f4b3 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -169,3 +169,38 @@ func.func @parallel_insert_slice(
return %r1: tensor<?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @eliminate_multiple_ops(
+// CHECK-SAME: %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>
+// CHECK-SAME: %[[sz:[0-9a-zA-Z]*]]: index
+func.func @eliminate_multiple_ops(%t: tensor<?xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>}, %sz: index, %c: i1)
+ -> (tensor<?xf32>)
+{
+ %cst1 = arith.constant 0.0: f32
+ %cst2 = arith.constant 1.0: f32
+
+ // CHECK: %[[r:.*]] = scf.if %{{.*}} -> (memref
+ %if = scf.if %c -> tensor<?xf32> {
+ // CHECK: %[[T_SUBVIEW_1:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
+ %a1 = tensor.empty(%sz) : tensor<?xf32>
+ // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_1]] : memref<?xf32
+ %f1 = linalg.fill ins(%cst1 : f32) outs(%a1 : tensor<?xf32>) -> tensor<?xf32>
+ // CHECK: scf.yield %[[T_SUBVIEW_1]]
+ scf.yield %f1 : tensor<?xf32>
+ } else {
+ // CHECK: %[[T_SUBVIEW_2:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
+ %a2 = tensor.empty(%sz) : tensor<?xf32>
+ // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_2]] : memref<?xf32
+ %f2 = linalg.fill ins(%cst2 : f32) outs(%a2 : tensor<?xf32>) -> tensor<?xf32>
+ // CHECK: scf.yield %[[T_SUBVIEW_2]]
+ scf.yield %f2 : tensor<?xf32>
+ }
+
+  // The self-copy could be canonicalized away later.
+ // CHECK: %[[T_SUBVIEW_3:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
+ // CHECK: memref.copy %[[r]], %[[T_SUBVIEW_3]]
+ %r1 = tensor.insert_slice %if into %t[42][%sz][1]: tensor<?xf32> into tensor<?xf32>
+ return %r1: tensor<?xf32>
+}
More information about the Mlir-commits
mailing list