[Mlir-commits] [mlir] 0fcaca2 - [mlir][bufferization] `MaterializeInDestinationOp`: Support memref destinations (#68074)
llvmlistbot at llvm.org
Fri Oct 6 02:57:15 PDT 2023
Author: Matthias Springer
Date: 2023-10-06T11:57:10+02:00
New Revision: 0fcaca2feaa973afa9275c0cb931775f5e3bde4c
URL: https://github.com/llvm/llvm-project/commit/0fcaca2feaa973afa9275c0cb931775f5e3bde4c
DIFF: https://github.com/llvm/llvm-project/commit/0fcaca2feaa973afa9275c0cb931775f5e3bde4c.diff
LOG: [mlir][bufferization] `MaterializeInDestinationOp`: Support memref destinations (#68074)
Extend `bufferization.materialize_in_destination` to support memref
destinations. This op can now be used to indicate that a tensor
computation should materialize in a given buffer (that may have been
allocated by another component/runtime). The op still participates in
"empty tensor elimination".
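The rules deciding which combinations of `dest` type, result, `restrict` and `writable` are legal (enforced by `MaterializeInDestinationOp::verify()` in the diff) can be summarized with a small standalone model. This is a hypothetical, MLIR-free C++ sketch: `DestKind`, `OpShape` and `verifyModel` are invented names, the checks are ordered like the patch's verifier, and the shape/element-type matching done by the `AllShapesMatch`/`AllElementTypesMatch` traits is not modeled.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical, MLIR-free model of the checks in
// MaterializeInDestinationOp::verify(). DestKind, OpShape and verifyModel are
// illustrative names, not MLIR API.
enum class DestKind { Tensor, MemRef, Other };

struct OpShape {
  DestKind dest;
  int numResults;         // number of op results (0 or 1 in valid IR)
  bool resultMatchesDest; // result type == dest type (if there is a result)
  bool hasRestrict;       // `restrict` unit attribute present
  bool hasWritable;       // `writable` unit attribute present
};

// Returns std::nullopt for valid IR, otherwise the first error message,
// checked in the same order as the verifier in the patch.
std::optional<std::string> verifyModel(const OpShape &op) {
  if (op.dest == DestKind::Other)
    return "'dest' must be a tensor or a memref";
  if (op.dest == DestKind::Tensor) {
    if (op.numResults != 1)
      return "tensor 'dest' implies exactly one tensor result";
    if (!op.resultMatchesDest)
      return "result and 'dest' types must match";
  }
  const bool isMemRef = op.dest == DestKind::MemRef;
  if (isMemRef && op.numResults != 0)
    return "memref 'dest' implies zero results";
  if (op.hasRestrict != isMemRef)
    return "'restrict' must be specified if and only if the destination is "
           "of memref type";
  if (op.hasWritable != isMemRef)
    return "'writable' must be specified if and only if the destination is "
           "of memref type";
  return std::nullopt;
}
```

In this model, the memref form used in this commit's example corresponds to `{DestKind::MemRef, 0, false, true, true}` and is accepted, while the tensor form needs exactly one matching result and neither attribute.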
Example:
```mlir
func.func @test(%out: memref<10xf32>) {
%t = tensor.empty() : tensor<10xf32>
%c = linalg.generic ... outs(%t: tensor<10xf32>) -> tensor<10xf32>
bufferization.materialize_in_destination %c in restrict writable %out : (tensor<10xf32>, memref<10xf32>) -> ()
return
}
```
After "empty tensor elimination", the above IR can bufferize without an
allocation:
```mlir
func.func @test(%out: memref<10xf32>) {
linalg.generic ... outs(%out: memref<10xf32>)
return
}
```
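When empty tensor elimination does not apply, the op still bufferizes to a copy (a `memref.tensor_store` emitted by `bufferize()` in the diff). The two forms differ only in where that copy lands and in what replaces the op. The following toy C++ sketch illustrates just that branching; it is not MLIR code, SSA values are modeled as strings, and the `"buffer(...)"` naming is a made-up stand-in for `getBuffer()`.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Toy model (not MLIR code) of the two paths in
// MaterializeInDestinationOp::bufferize(). SSA values are modeled as strings.
struct BufferizeOutcome {
  std::string copyTarget;                 // buffer written by the copy
  std::optional<std::string> replacement; // value replacing the op's result
};

BufferizeOutcome bufferizeModel(bool destIsTensor, const std::string &dest) {
  if (destIsTensor) {
    // Tensor dest: look up its future buffer (stand-in for getBuffer()) and
    // replace the op's tensor result with that buffer.
    std::string buffer = "buffer(" + dest + ")";
    return {buffer, buffer};
  }
  // Memref dest: it already is the buffer; the op has no result to replace.
  return {dest, std::nullopt};
}
```

With a memref destination no extra buffer is looked up, which is why empty tensor elimination can make the computation write straight into `%out`.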
This change also clarifies the meaning of the `restrict` unit attribute
on `bufferization.to_tensor` ops.
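The clarified `restrict` contract (an op marked `restrict` must be the only way for tensor IR to reach its buffer) can be phrased as a counting rule. This toy C++ sketch is a hypothetical illustration, not One-Shot Bufferize code: `BufferAccess` and `satisfiesRestrict` are invented names, buffers are identified by integer ids, and it assumes that different ids never alias, whereas real IR may contain aliasing memrefs (e.g., subviews) that would also count.

```cpp
#include <cassert>
#include <map>
#include <vector>

// Toy model (not MLIR code) of the `restrict` contract. Each entry is an op
// that gives tensor IR access to a buffer, identified by an integer id.
// Simplification: buffers with different ids are assumed not to alias.
struct BufferAccess {
  int bufferId;
  bool hasRestrict;
};

// Returns true if every access marked `restrict` is the only access to its
// buffer.
bool satisfiesRestrict(const std::vector<BufferAccess> &accesses) {
  std::map<int, int> count;
  for (const BufferAccess &a : accesses)
    ++count[a.bufferId];
  for (const BufferAccess &a : accesses)
    if (a.hasRestrict && count[a.bufferId] > 1)
      return false;
  return true;
}
```

For instance, two `restrict` accesses to the same buffer id violate the contract, while two accesses to distinct ids satisfy it.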
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
mlir/test/Dialect/Bufferization/invalid.mlir
mlir/test/Dialect/Bufferization/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index db93a51775ffcd7..09ce2981d382680 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -216,33 +216,58 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
def Bufferization_MaterializeInDestinationOp
: Bufferization_Op<"materialize_in_destination",
- [BufferizableOpInterface, SameOperandsAndResultType,
- DestinationStyleOpInterface,
+ [AllShapesMatch<["source", "dest"]>,
+ AllElementTypesMatch<["source", "dest"]>,
+ BufferizableOpInterface, DestinationStyleOpInterface,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
- "buildSubsetExtraction", "isEquivalentSubset"]>]> {
+ "buildSubsetExtraction", "isEquivalentSubset"]>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface, ["getEffects"]>]> {
let summary = "copy a tensor";
let description = [{
This op indicates that the data of the `source` tensor should materialize
- in the future buffer of the `dest` tensors. Both tensors must have the same
- shape and element type at runtime.
+ in `dest`, which can be a tensor or a memref. In case of a tensor, `source`
+ should materialize in the future buffer of `dest` and the updated
+ destination tensor is returned. In case of a memref, `source` should
+ materialize in `dest`, which is already a buffer. The op has no results in
+ that case.
+
+ `source`, `dest` and `result` (if present) must have the same shape and
+ element type. If the op has a result, the types of `result` and `dest` must
+ match exactly (e.g., including any tensor encodings).
By default, this op bufferizes to a memcpy from the future buffer of the
- `source` tensor to the future buffer of the `dest` tensor. However,
- transformations such as "empty tensor elimination" may rewrite IR such that
- a computation is performed directly in the future buffer of the `dest`
- tensor and no memcpy is needed.
-
- Note: "tensor.insert_slice" could be used for the same purpose, but since
- tensor dialect ops only indicate *what* should be computed but not *where*,
- it could fold away, causing the computation to materialize in a different
- buffer.
+ `source` tensor to the future buffer of the `dest` tensor or to the `dest`
+ buffer. However, transformations such as "empty tensor elimination" may
+ rewrite IR such that a computation is performed directly in `dest` and no
+ memcpy is needed.
+
+ If `dest` is a buffer, the `restrict` and `writable` attributes must be
+ specified. These attributes have the same meaning as the respective
+ attributes of `bufferization.to_tensor`. `writable` indicates that the
+ `dest` buffer is considered writable. It does not make sense to materialize
+ a computation in a read-only buffer, so `writable` is required. `restrict`
+ indicates that this op is the only way for the tensor IR to access `dest`
+ (or an alias thereof). E.g., there must be no other `to_tensor` ops with
+ `dest` or with an alias of `dest`. Such IR is not supported by
+ One-Shot Bufferize.
+
+ Note: `restrict` and `writable` could be removed from this op because they
+ must always be set for memref destinations. This op has these attributes to
+ make clear the requirements on the `dest` operand in the op assembly format.
+ Moreover, these requirements may be relaxed at some point in the future.
+
+ Note: If `dest` is a tensor, `tensor.insert_slice` could be used for the
+ same purpose, but since tensor dialect ops only indicate *what* should be
+ computed but not *where*, it could fold away, causing the computation to
+ materialize in a different buffer.
}];
- let arguments = (ins AnyTensor:$source, AnyTensor:$dest);
- let results = (outs AnyTensor:$result);
+ let arguments = (ins AnyTensor:$source, AnyShaped:$dest,
+ UnitAttr:$restrict, UnitAttr:$writable);
+ let results = (outs Optional<AnyTensor>:$result);
let extraClassDeclaration = [{
LogicalResult bufferize(RewriterBase &rewriter,
@@ -264,10 +289,23 @@ def Bufferization_MaterializeInDestinationOp
return ::llvm::cast<RankedTensorType>(getResult().getType());
}
- MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
+ MutableOperandRange getDpsInitsMutable();
+
+ bool isWritable(Value value, const AnalysisState &state);
}];
- let assemblyFormat = "$source `in` $dest attr-dict `:` type($source)";
+ let builders = [
+ // Builder that materializes a source tensor in a tensor destination.
+ // Asserts that `dest` has tensor type. Infers the result type of this op
+ // from the destination tensor.
+ OpBuilder<(ins "Value":$source, "Value":$dest)>
+ ];
+
+ let assemblyFormat = [{
+ $source `in` (`restrict` $restrict^)? (`writable` $writable^)? $dest
+ attr-dict `:` functional-type(operands, results)
+ }];
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -361,10 +399,15 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
thereof) will bufferize out-of-place to prevent emitting any writes to
`memref` during bufferization.
- If the given memref does not alias with any other memref passed to another
- `to_tensor` op, the `restrict` unit attribute can be set. Only such
- operations are supported by One-Shot Bufferize. (Otherwise, potential memref
- aliasing relationships would have to be captured in One-Shot Bufferize.)
+ The `restrict` unit attribute (similar to the C `restrict` keyword)
+ indicates that the produced tensor result is the only way for the tensor
+ IR to gain access to the `memref` operand (or an alias thereof). E.g.,
+ there must be no other `to_tensor` op with the same or with an aliasing
+ `memref` operand.
+
+ Note: Only `to_tensor` ops with the `restrict` unit attribute are supported
+ by One-Shot Bufferize. Other IR is rejected. (To support `to_tensor`
+ without `restrict`, One-Shot Bufferize would have to analyze memref IR.)
Example:
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 01cbacc96fd42d2..1c33f444d15850c 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -542,25 +542,40 @@ bool MaterializeInDestinationOp::bufferizesToMemoryRead(
bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
OpOperand &opOperand, const AnalysisState &state) {
- return &opOperand == &getDestMutable();
+ if (&opOperand == &getDestMutable()) {
+ assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
+ return true;
+ }
+ return false;
}
AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
const AnalysisState &state) {
- if (&opOperand == &getDestMutable())
+ if (&opOperand == &getDestMutable()) {
+ assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
+ }
return {};
}
LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
const BufferizationOptions &options) {
- FailureOr<Value> buffer = getBuffer(rewriter, getDest(), options);
- if (failed(buffer))
- return failure();
- rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), *buffer);
- replaceOpWithBufferizedValues(rewriter, getOperation(), *buffer);
+ bool tensorDest = isa<TensorType>(getDest().getType());
+ Value buffer;
+ if (tensorDest) {
+ FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
+ if (failed(maybeBuffer))
+ return failure();
+ buffer = *maybeBuffer;
+ } else {
+ assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
+ buffer = getDest();
+ }
+ rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), buffer);
+ replaceOpWithBufferizedValues(rewriter, getOperation(),
+ tensorDest ? ValueRange(buffer) : ValueRange());
return success();
}
@@ -573,15 +588,29 @@ bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
LogicalResult MaterializeInDestinationOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
- reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
- reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
+ if (getOperation()->getNumResults() == 1) {
+ assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
+ reifiedReturnShapes.resize(1,
+ SmallVector<OpFoldResult>(getType().getRank()));
+ reifiedReturnShapes[0] =
+ tensor::getMixedSizes(builder, getLoc(), getDest());
+ }
return success();
}
Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
Location loc) {
- // The subset is the entire destination tensor.
- return getDest();
+ if (isa<TensorType>(getDest().getType())) {
+ // The subset is the entire destination tensor.
+ return getDest();
+ }
+
+ // Build a bufferization.to_tensor op.
+ assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
+ assert(getRestrict() &&
+ "expected that ops with a memref dest have 'restrict'");
+ return builder.create<ToTensorOp>(loc, getDest(), getRestrict(),
+ getWritable());
}
bool MaterializeInDestinationOp::isEquivalentSubset(
@@ -598,6 +627,51 @@ OpOperand &MaterializeInDestinationOp::getSourceOperand() {
return getOperation()->getOpOperand(0) /*source*/;
}
+LogicalResult MaterializeInDestinationOp::verify() {
+ if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
+ return emitOpError("'dest' must be a tensor or a memref");
+ if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
+ if (getOperation()->getNumResults() != 1)
+ return emitOpError("tensor 'dest' implies exactly one tensor result");
+ if (destType != getResult().getType())
+ return emitOpError("result and 'dest' types must match");
+ }
+ if (isa<BaseMemRefType>(getDest().getType()) &&
+ getOperation()->getNumResults() != 0)
+ return emitOpError("memref 'dest' implies zero results");
+ if (getRestrict() != isa<BaseMemRefType>(getDest().getType()))
+ return emitOpError("'restrict' must be specified if and only if the "
+ "destination is of memref type");
+ if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
+ return emitOpError("'writable' must be specified if and only if the "
+ "destination is of memref type");
+ return success();
+}
+
+void MaterializeInDestinationOp::build(OpBuilder &builder,
+ OperationState &state, Value source,
+ Value dest) {
+ assert(isa<TensorType>(dest.getType()) && "expected tensor type");
+ build(builder, state, /*result=*/dest.getType(), source, dest);
+}
+
+bool MaterializeInDestinationOp::isWritable(Value value,
+ const AnalysisState &state) {
+ return isa<TensorType>(getDest().getType()) ? true : getWritable();
+}
+
+MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
+ return getDestMutable();
+}
+
+void MaterializeInDestinationOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (isa<BaseMemRefType>(getDest().getType()))
+ effects.emplace_back(MemoryEffects::Write::get(), getDest(),
+ SideEffects::DefaultResource::get());
+}
+
//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index a74a3c2c500406f..e6d80a39650ccf0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -248,8 +248,10 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
LinalgPaddingOptions::CopyBackOp::
BufferizationMaterializeInDestination) {
replacements.push_back(
- rewriter.create<bufferization::MaterializeInDestinationOp>(
- loc, std::get<0>(it), std::get<1>(it).get()));
+ rewriter
+ .create<bufferization::MaterializeInDestinationOp>(
+ loc, std::get<0>(it), std::get<1>(it).get())
+ ->getResult(0));
} else {
llvm_unreachable("unsupported copy back op");
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
index a2fbb06d179ebda..c3e44c426797f39 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -172,7 +172,7 @@ func.func @materialize_in_destination_aliasing(%t: tensor<?xf32>, %p1: index, %p
%dest = tensor.extract_slice %t[%p2][5][1] : tensor<?xf32> to tensor<5xf32>
// CHECK: bufferization.materialize_in_destination
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
- %r = bufferization.materialize_in_destination %src in %dest : tensor<5xf32>
+ %r = bufferization.materialize_in_destination %src in %dest : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
return %r : tensor<5xf32>
}
@@ -183,6 +183,6 @@ func.func @materialize_in_destination(%t: tensor<?xf32>, %sz: index) -> tensor<?
%buffer = tensor.empty(%sz) : tensor<?xf32>
// CHECK: bufferization.materialize_in_destination
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
- %r = bufferization.materialize_in_destination %buffer in %buffer : tensor<?xf32>
+ %r = bufferization.materialize_in_destination %buffer in %buffer : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %r : tensor<?xf32>
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index b68682a459ed2c2..99b974b9ef3c67e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -301,12 +301,25 @@ func.func @regression_multiple_insertion_points(%t1: tensor<?x?xf32>) -> tensor<
func.func @materialize_in_destination(%t: tensor<5xf32>, %f: f32) -> tensor<5xf32> {
%0 = tensor.empty() : tensor<5xf32>
%filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
- %1 = bufferization.materialize_in_destination %filled in %t : tensor<5xf32>
+ %1 = bufferization.materialize_in_destination %filled in %t : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
return %1 : tensor<5xf32>
}
// -----
+// CHECK-LABEL: func @materialize_in_destination_buffer(
+// CHECK-SAME: %[[m:.*]]: memref<5xf32>,
+// CHECK-NEXT: linalg.fill {{.*}} outs(%[[m]]
+// CHECK-NEXT: return
+func.func @materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32) {
+ %0 = tensor.empty() : tensor<5xf32>
+ %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+ bufferization.materialize_in_destination %filled in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @linalg_copy(
// CHECK-SAME: %[[m:.*]]: memref<5xf32, strided<[?], offset: ?>>,
// CHECK: linalg.fill {{.*}} outs(%[[m]]
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index 3f468750cc28405..272423de5564b09 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -218,6 +218,20 @@ func.func @tensor_copy(%arg0: tensor<5xf32>) -> tensor<5xf32> {
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
// CHECK: return %[[r]]
%dest = bufferization.alloc_tensor() : tensor<5xf32>
- %0 = bufferization.materialize_in_destination %arg0 in %dest : tensor<5xf32>
+ %0 = bufferization.materialize_in_destination %arg0 in %dest
+ : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
return %0 : tensor<5xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @materialize_in_destination_buffer(
+// CHECK-SAME: %[[t:.*]]: tensor<5xf32>, %[[m:.*]]: memref<5xf32>)
+// CHECK: %[[b:.*]] = bufferization.to_memref %[[t]] : memref<5xf32, strided<[?], offset: ?>>
+// CHECK: memref.copy %[[b]], %[[m]]
+func.func @materialize_in_destination_buffer(%t: tensor<5xf32>, %m: memref<5xf32>) {
+ bufferization.materialize_in_destination %t in restrict writable %m
+ : (tensor<5xf32>, memref<5xf32>) -> ()
+ return
+}
+
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 8004ec632453e8c..ce56f89c1f1bbe6 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -66,10 +66,58 @@ func.func @invalid_writable_on_op() {
// -----
-// expected-note @below{{prior use here}}
func.func @invalid_materialize_in_destination(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
- // expected-error @below{{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<5xf32>'}}
- bufferization.materialize_in_destination %arg0 in %arg1 : tensor<?xf32>
+ // expected-error @below{{failed to verify that all of {source, dest} have same shape}}
+ bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_dest_type(%arg0: tensor<5xf32>, %arg1: vector<5xf32>) {
+ // expected-error @below{{'dest' must be a tensor or a memref}}
+ bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<5xf32>, vector<5xf32>) -> ()
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_restrict_missing(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
+ // expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
+ bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, memref<?xf32>) -> ()
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_result(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
+ // expected-error @below{{memref 'dest' implies zero results}}
+ bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, memref<?xf32>) -> (tensor<?xf32>)
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_result_missing(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+ // expected-error @below{{tensor 'dest' implies exactly one tensor result}}
+ bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> ()
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+ // expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
+ bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_writable(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+ // expected-error @below{{'writable' must be specified if and only if the destination is of memref type}}
+ bufferization.materialize_in_destination %arg0 in writable %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_result_shape(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+ // expected-error @below{{result and 'dest' types must match}}
+ bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<6xf32>)
}
// -----
diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
index dc53e535bfe0d57..d4bda0632189d41 100644
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -59,10 +59,12 @@ func.func @test_dealloc_tensor_op(%arg0: tensor<4xi32>) {
}
// CHECK-LABEL: func @test_materialize_in_destination_op
-func.func @test_materialize_in_destination_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>)
+func.func @test_materialize_in_destination_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: memref<?xf32, 3>)
-> tensor<?xf32> {
- // CHECK: bufferization.materialize_in_destination {{.*}} : tensor<?xf32>
- %1 = bufferization.materialize_in_destination %arg0 in %arg1 : tensor<?xf32>
+ // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+ %1 = bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+ // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, memref<?xf32, 3>) -> ()
+ bufferization.materialize_in_destination %arg0 in restrict writable %arg2 : (tensor<?xf32>, memref<?xf32, 3>) -> ()
return %1 : tensor<?xf32>
}