[Mlir-commits] [mlir] 8ddd98f - [mlir][linalg] Return newly created ops from bufferize_to_allocation
Matthias Springer
llvmlistbot at llvm.org
Tue Jul 11 07:34:14 PDT 2023
Author: Matthias Springer
Date: 2023-07-11T16:34:02+02:00
New Revision: 8ddd98f83136646bd4f88fe35919b0af03337b75
URL: https://github.com/llvm/llvm-project/commit/8ddd98f83136646bd4f88fe35919b0af03337b75
DIFF: https://github.com/llvm/llvm-project/commit/8ddd98f83136646bd4f88fe35919b0af03337b75.diff
LOG: [mlir][linalg] Return newly created ops from bufferize_to_allocation
Return all ops that were generated as part of the bufferization, so that users do not have to match them in the enclosing op.
Differential Revision: https://reviews.llvm.org/D154966
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir
mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 2979a8018cdf3a..4a143a158867cb 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -87,8 +87,8 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
result in a new allocation. It replaces all original uses of the target
result with the newly allocated buffer, wrapped in a
`bufferization.to_tensor` op. It returns a handle to the newly allocated
- buffer. Furthermore, it returns a handle to the result of the `to_tensor`
- op.
+ buffer. Furthermore, it returns a handle that is mapped to all newly created
+ ops.
Only bufferizable ops are that bufferize to a memory write or have an
aliasing OpOperand (and do not themselves bufferize to an allocation) are
@@ -121,12 +121,13 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
#### Return modes
This operation consumes the `target` handle and produces the
- `allocated_buffer` handle. It always succeeds.
+ `allocated_buffer` and `new_ops` handles. It always succeeds.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
OptionalAttr<AnyAttr>:$memory_space);
- let results = (outs Transform_AnyValue:$allocated_buffer);
+ let results = (outs Transform_AnyValue:$allocated_buffer,
+ Transform_AnyOpType:$new_ops);
let assemblyFormat = "$target attr-dict `:` type($target)";
let builders = [
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index fdabbcc05181dd..6625ef553eba21 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -444,7 +444,7 @@ class RewriterBase : public OpBuilder {
/// struct can be used as a base to create listener chains, so that multiple
/// listeners can be notified of IR changes.
struct ForwardingListener : public RewriterBase::Listener {
- ForwardingListener(Listener *listener) : listener(listener) {}
+ ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
void notifyOperationInserted(Operation *op) override {
listener->notifyOperationInserted(op);
@@ -453,26 +453,32 @@ class RewriterBase : public OpBuilder {
listener->notifyBlockCreated(block);
}
void notifyOperationModified(Operation *op) override {
- listener->notifyOperationModified(op);
+ if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+ rewriteListener->notifyOperationModified(op);
}
void notifyOperationReplaced(Operation *op, Operation *newOp) override {
- listener->notifyOperationReplaced(op, newOp);
+ if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+ rewriteListener->notifyOperationReplaced(op, newOp);
}
void notifyOperationReplaced(Operation *op,
ValueRange replacement) override {
- listener->notifyOperationReplaced(op, replacement);
+ if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+ rewriteListener->notifyOperationReplaced(op, replacement);
}
void notifyOperationRemoved(Operation *op) override {
- listener->notifyOperationRemoved(op);
+ if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+ rewriteListener->notifyOperationRemoved(op);
}
LogicalResult notifyMatchFailure(
Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override {
- return listener->notifyMatchFailure(loc, reasonCallback);
+ if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+ return rewriteListener->notifyMatchFailure(loc, reasonCallback);
+ return failure();
}
private:
- Listener *listener;
+ OpBuilder::Listener *listener;
};
/// Move the blocks that belong to "region" before the given position in
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7ca6b272103b7a..5474377ee364b0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -177,8 +177,11 @@ void transform::BufferizeToAllocationOp::build(OpBuilder &b,
OperationState &result,
Value target,
Attribute memorySpace) {
+ SmallVector<Type> resultTypes;
+ resultTypes.push_back(b.getType<transform::AnyValueType>());
+ resultTypes.push_back(b.getType<transform::AnyOpType>());
return build(b, result,
- /*resultTypes=*/b.getType<transform::AnyValueType>(),
+ /*resultTypes=*/resultTypes,
/*target=*/target,
/*memorySpace=*/memorySpace);
}
@@ -187,15 +190,52 @@ void transform::BufferizeToAllocationOp::build(OpBuilder &b,
OperationState &result,
Value target,
int64_t memorySpace) {
+ SmallVector<Type> resultTypes;
+ resultTypes.push_back(b.getType<transform::AnyValueType>());
+ resultTypes.push_back(b.getType<transform::AnyOpType>());
return build(b, result,
- /*resultTypes=*/b.getType<transform::AnyValueType>(),
+ /*resultTypes=*/resultTypes,
/*target=*/target,
/*memorySpace=*/b.getI64IntegerAttr(memorySpace));
}
+namespace {
+/// Listener that forwards every notification to a wrapped listener and, in
+/// addition, keeps track of all ops reported via notifyOperationInserted that
+/// have not since been dropped via notifyOperationRemoved (either directly or
+/// as an op nested inside a removed op).
+///
+/// NOTE(review): the wrapped listener must be non-null; ForwardingListener
+/// dereferences it unconditionally when forwarding notifyOperationInserted.
+class NewOpsListener : public RewriterBase::ForwardingListener {
+public:
+  using RewriterBase::ForwardingListener::ForwardingListener;
+
+  /// Return the tracked ops in the order in which they were inserted.
+  /// Iterating a pointer-keyed hash container directly would produce an order
+  /// that depends on pointer values and can differ between runs, which would
+  /// make the resulting transform handle nondeterministic.
+  SmallVector<Operation *> getNewOps() const {
+    SmallVector<std::pair<unsigned, Operation *>> ordered;
+    ordered.reserve(newOps.size());
+    for (const auto &entry : newOps)
+      ordered.emplace_back(entry.second, entry.first);
+    // Creation indices are unique, so sorting the pairs orders by index only.
+    llvm::sort(ordered);
+    SmallVector<Operation *> result;
+    result.reserve(ordered.size());
+    for (const auto &entry : ordered)
+      result.push_back(entry.second);
+    return result;
+  }
+
+private:
+  void notifyOperationInserted(Operation *op) override {
+    ForwardingListener::notifyOperationInserted(op);
+    // Remember when the op was created so that getNewOps() can report ops in
+    // a deterministic (creation) order.
+    bool inserted = newOps.try_emplace(op, nextIndex++).second;
+    (void)inserted;
+    assert(inserted && "expected newly created op");
+  }
+
+  void notifyOperationRemoved(Operation *op) override {
+    ForwardingListener::notifyOperationRemoved(op);
+    // Drop `op` and every op nested inside it (walk visits nested ops as
+    // well), presumably because they go away together with `op`; this keeps
+    // stale pointers out of the result handle.
+    op->walk([&](Operation *nested) { newOps.erase(nested); });
+  }
+
+  /// Tracked ops, mapped to their creation index.
+  DenseMap<Operation *, unsigned> newOps;
+  /// Creation index assigned to the next inserted op.
+  unsigned nextIndex = 0;
+};
+} // namespace
+
DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
+ // Attach listener to keep track of newly created ops.
+ OpBuilder::Listener *previousListener = rewriter.getListener();
+ auto resetListener =
+ llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
+ NewOpsListener newOpsListener(previousListener);
+ rewriter.setListener(&newOpsListener);
+
+ // Bufferize ops.
Attribute memorySpace =
getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
SmallVector<Value> allocatedBuffers;
@@ -209,7 +249,10 @@ DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
}
allocatedBuffers.push_back(buffer);
}
+
+ // Set results.
results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
+ results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
return DiagnosedSilenceableFailure::success();
}
@@ -217,6 +260,7 @@ void transform::BufferizeToAllocationOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTarget(), effects);
producesHandle(getAllocatedBuffer(), effects);
+ producesHandle(getNewOps(), effects);
modifiesPayload(effects);
}
diff --git a/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir b/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir
index 55b0ff5016d177..2643935049c9f3 100644
--- a/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir
+++ b/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir
@@ -54,7 +54,7 @@ transform.sequence failures(propagate) {
padding_dimensions=[0, 1, 2],
pack_paddings=[1, 1, 1]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %buffer = transform.structured.bufferize_to_allocation %pad {memory_space = 3} : !transform.any_op
+ %buffer, %new_ops = transform.structured.bufferize_to_allocation %pad {memory_space = 3} : !transform.any_op
%2 = transform.bufferization.one_shot_bufferize %arg1 {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op
}
@@ -114,6 +114,6 @@ transform.sequence failures(propagate) {
transform.structured.masked_vectorize %pad vector_sizes [10, 12] : !transform.any_op
%vector_write = transform.structured.match ops{["vector.transfer_write"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%mask_op = transform.get_parent_op %vector_write {op_name = "vector.mask"} : (!transform.any_op) -> !transform.any_op
- %buffer = transform.structured.bufferize_to_allocation %mask_op {memory_space = 3} : !transform.any_op
+ %buffer, %new_ops = transform.structured.bufferize_to_allocation %mask_op {memory_space = 3} : !transform.any_op
%2 = transform.bufferization.one_shot_bufferize %arg1 {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
index e51c20334e9554..45efde3b077a44 100644
--- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -32,7 +32,17 @@ func.func @tensor_pad_constant(%t: tensor<?x10xindex>, %l2: index, %h1: index,
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.bufferize_to_allocation %0 : !transform.any_op
+ %2, %new = transform.structured.bufferize_to_allocation %0 : !transform.any_op
+
+ // Ensure that one linalg.fill was generated.
+ %fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
+ // expected-remark @below{{1}}
+ test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
+
+ // Ensure that one memref.tensor_store was generated.
+ %tensor_store = transform.select "memref.tensor_store" in %new : (!transform.any_op) -> !transform.any_op
+ // expected-remark @below{{1}}
+ test_print_number_of_associated_payload_ir_ops %tensor_store : !transform.any_op
}
// -----
@@ -57,7 +67,7 @@ func.func @tensor_pad_constant(%t: tensor<?x10xindex>, %l2: index, %h1: index,
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.bufferize_to_allocation %0 : !transform.any_op
+ %2, %new = transform.structured.bufferize_to_allocation %0 : !transform.any_op
// Make sure that One-Shot Bufferize can bufferize the rest.
%4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
}
@@ -81,7 +91,7 @@ func.func @tensor_insert(%t: tensor<?x10xindex>, %idx: index, %v: index) -> tens
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["tensor.insert"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
+ %2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
// Make sure that One-Shot Bufferize can bufferize the rest.
%4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
}
@@ -104,7 +114,7 @@ func.func @tensor_insert_into_empty(%idx: index, %v: index) -> tensor<10xindex>
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["tensor.insert"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
+ %2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
// Make sure that One-Shot Bufferize can bufferize the rest.
%4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
}
@@ -121,7 +131,7 @@ transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["tensor.extract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @below{{failed to bufferize operation}}
- %2 = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
+ %2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
}
// -----
@@ -142,5 +152,5 @@ func.func @vector_mask(%t: tensor<?xf32>, %val: vector<16xf32>, %idx: index, %m0
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["vector.mask"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
+ %2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
}
More information about the Mlir-commits
mailing list