[Mlir-commits] [mlir] [mlir][bufferization] `MaterializeInDestinationOp`: Support memref destinations (PR #68074)

Matthias Springer llvmlistbot at llvm.org
Tue Oct 3 00:50:52 PDT 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/68074

Extend `bufferization.materialize_in_destination` to support memref destinations. This op can now be used to indicate that a tensor computation should materialize in a given buffer (that may have been allocated by another component/runtime). The op still participates in "empty tensor elimination".

Example:
```
func.func @test(%out: memref<10xf32>) {
  %t = tensor.empty() : tensor<10xf32>
  %c = linalg.generic ... outs(%t: tensor<10xf32>) -> tensor<10xf32>
  bufferization.materialize_in_destination %c in restrict writable %out : (tensor<10xf32>, memref<10xf32>) -> ()
  return
}
```

After "empty tensor elimination", the above IR can bufferize without an allocation. The "linalg.generic" is computed directly on %out.
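
For comparison, the two forms of the op after this change can be sketched side by side. This is a minimal illustration assembled from the test cases in the patch below. The function and value names (`@tensor_dest`, `@memref_dest`, `%src`, `%out`) are hypothetical. Note that the verifier added by this patch requires `restrict` and `writable` if and only if the destination is a memref.

```mlir
// Tensor destination: the op returns the updated destination tensor, and
// the result type must match the type of the destination exactly.
func.func @tensor_dest(%src: tensor<5xf32>, %dest: tensor<5xf32>) -> tensor<5xf32> {
  %r = bufferization.materialize_in_destination %src in %dest
      : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
  return %r : tensor<5xf32>
}

// Memref destination: `restrict` and `writable` must both be present and
// the op has no results.
func.func @memref_dest(%src: tensor<5xf32>, %out: memref<5xf32>) {
  bufferization.materialize_in_destination %src in restrict writable %out
      : (tensor<5xf32>, memref<5xf32>) -> ()
  return
}
```

Per the One-Shot Bufferize test in the patch, the memref form bufferizes to a `memref.copy` into the destination by default, unless a transformation such as "empty tensor elimination" makes the copy unnecessary.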

From ea217da17ee905198a9759f63022a18cf4646c93 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 3 Oct 2023 09:49:41 +0200
Subject: [PATCH] [mlir][bufferization] `MaterializeInDestinationOp`: Support
 memref destinations

Extend `bufferization.materialize_in_destination` to support memref destinations. This op can now be used to indicate that a tensor computation should materialize in a given buffer (that may have been allocated by another component/runtime). The op still participates in "empty tensor elimination".

Example:
```
func.func @test(%out: memref<10xf32>) {
  %t = tensor.empty() : tensor<10xf32>
  %c = linalg.generic ... outs(%t: tensor<10xf32>) -> tensor<10xf32>
  bufferization.materialize_in_destination %c in restrict writable %out : (tensor<10xf32>, memref<10xf32>) -> ()
  return
}
```

After "empty tensor elimination", the above IR can bufferize without an allocation. The "linalg.generic" is computed directly on %out.
---
 .../Bufferization/IR/BufferizationOps.td      | 72 +++++++++++----
 .../Bufferization/IR/BufferizationOps.cpp     | 92 ++++++++++++++++---
 .../Transforms/EmptyTensorElimination.cpp     |  1 +
 .../lib/Dialect/Linalg/Transforms/Padding.cpp |  6 +-
 .../one-shot-bufferize-analysis.mlir          |  4 +-
 ...ot-bufferize-empty-tensor-elimination.mlir | 15 ++-
 .../Transforms/one-shot-bufferize.mlir        | 15 ++-
 mlir/test/Dialect/Bufferization/invalid.mlir  | 47 +++++++++-
 mlir/test/Dialect/Bufferization/ops.mlir      |  8 +-
 9 files changed, 219 insertions(+), 41 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 9761ab12134ad28..68d64e685eeabcb 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -216,33 +216,56 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
 
 def Bufferization_MaterializeInDestinationOp
     : Bufferization_Op<"materialize_in_destination",
-        [BufferizableOpInterface, SameOperandsAndResultType,
-         DestinationStyleOpInterface,
+        [AllShapesMatch<["source", "dest"]>,
+         AllElementTypesMatch<["source", "dest"]>,
+         BufferizableOpInterface, DestinationStyleOpInterface,
          DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
          DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
             ["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
-             "buildSubsetExtraction", "isEquivalentSubset"]>]> {
+             "buildSubsetExtraction", "isEquivalentSubset"]>,
+         DeclareOpInterfaceMethods<MemoryEffectsOpInterface, ["getEffects"]>]> {
   let summary = "copy a tensor";
 
   let description = [{
     This op indicates that the data of the `source` tensor should materialize
-    in the future buffer of the `dest` tensors. Both tensors must have the same
-    shape and element type at runtime.
+    in `dest`, which can be a tensor or a memref. In case of a tensor, `source`
+    should materialize in the future buffer of `dest` and the updated
+    destination tensor is returned. In case of a memref, `source` should
+    materialize in `dest`, which is already a buffer. The op has no results in
+    that case.
+
+    `source`, `dest` and `result` (if present) must have the same shape and
+    element type. If the op has a result, the types of `result` and `dest` must
+    match exactly (e.g., including any tensor encodings).
 
     By default, this op bufferizes to a memcpy from the future buffer of the
-    `source` tensor to the future buffer of the `dest` tensor. However,
-    transformations such as "empty tensor elimination" may rewrite IR such that
-    a computation is performed directly in the future buffer of the `dest`
-    tensor and no memcpy is needed.
-
-    Note: "tensor.insert_slice" could be used for the same purpose, but since
-    tensor dialect ops only indicate *what* should be computed but not *where*,
-    it could fold away, causing the computation to materialize in a different
-    buffer.
+    `source` tensor to the future buffer of the `dest` tensor or to the `dest`
+    buffer. However, transformations such as "empty tensor elimination" may
+    rewrite IR such that a computation is performed directly in `dest` and no
+    memcpy is needed.
+
+    If `dest` is a buffer, the "restrict" and "writable" attributes must be
+    specified. These attributes have the same meaning as the respective
+    attributes of "bufferization.to_tensor". "writable" indicates that the
+    `dest` buffer is considered writable. It does not make sense to materialize
+    a computation in a read-only buffer, so "writable" is required. "restrict"
+    indicates that `dest` does not alias with any memref passed to a "to_tensor"
+    op. Such aliasing is not supported by One-Shot Bufferize.
+
+    Note: "restrict" and "writable" could be removed from this op because they
+    must always be set for memref destinations. This op has these attributes to
+    make clear the requirements on the `dest` operand in the op assembly format.
+    Moreover, these requirements may be relaxed at some point in the future.
+
+    Note: If `dest` is a tensor, "tensor.insert_slice" could be used for the
+    same purpose, but since tensor dialect ops only indicate *what* should be
+    computed but not *where*, it could fold away, causing the computation to
+    materialize in a different buffer.
   }];
 
-  let arguments = (ins AnyTensor:$source, AnyTensor:$dest);
-  let results = (outs AnyTensor:$result);
+  let arguments = (ins AnyTensor:$source, AnyShaped:$dest,
+                       UnitAttr:$restrict, UnitAttr:$writable);
+  let results = (outs Optional<AnyTensor>:$result);
 
   let extraClassDeclaration = [{
     LogicalResult bufferize(RewriterBase &rewriter,
@@ -264,10 +287,23 @@ def Bufferization_MaterializeInDestinationOp
       return ::llvm::cast<RankedTensorType>(getResult().getType());
     }
 
-    MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
+    MutableOperandRange getDpsInitsMutable();
+
+    bool isWritable(Value value, const AnalysisState &state);
   }];
 
-  let assemblyFormat = "$source `in` $dest attr-dict `:` type($source)";
+  let builders = [
+    // Builder that materializes a source tensor in a tensor destination.
+    // Asserts that `dest` has tensor type. Infers the result type of this op
+    // from the destination tensor.
+    OpBuilder<(ins "Value":$source, "Value":$dest)>
+  ];
+
+  let assemblyFormat = [{
+    $source `in` (`restrict` $restrict^)? (`writable` $writable^)? $dest
+        attr-dict `:` functional-type(operands, results)
+  }];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 7c6c1be351cced1..5b88b0201f05d40 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -542,13 +542,15 @@ bool MaterializeInDestinationOp::bufferizesToMemoryRead(
 
 bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
     OpOperand &opOperand, const AnalysisState &state) {
-  return &opOperand == &getDestMutable()[0];
+  return isa<TensorType>(getDest().getType()) &&
+         &opOperand == &getDestMutable()[0];
 }
 
 AliasingValueList
 MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                               const AnalysisState &state) {
-  if (&opOperand == &getDestMutable()[0])
+  if (isa<TensorType>(getDest().getType()) &&
+      &opOperand == &getDestMutable()[0])
     return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
   return {};
 }
@@ -556,11 +558,20 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
 LogicalResult
 MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
-  FailureOr<Value> buffer = getBuffer(rewriter, getDest(), options);
-  if (failed(buffer))
-    return failure();
-  rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), *buffer);
-  replaceOpWithBufferizedValues(rewriter, getOperation(), *buffer);
+  bool tensorDest = isa<TensorType>(getDest().getType());
+  Value buffer;
+  if (tensorDest) {
+    FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
+    if (failed(maybeBuffer))
+      return failure();
+    buffer = *maybeBuffer;
+  } else {
+    assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
+    buffer = getDest();
+  }
+  rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), buffer);
+  replaceOpWithBufferizedValues(rewriter, getOperation(),
+                                tensorDest ? ValueRange(buffer) : ValueRange());
   return success();
 }
 
@@ -573,15 +584,29 @@ bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
 
 LogicalResult MaterializeInDestinationOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
-  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
+  if (getOperation()->getNumResults() == 1) {
+    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
+    reifiedReturnShapes.resize(1,
+                               SmallVector<OpFoldResult>(getType().getRank()));
+    reifiedReturnShapes[0] =
+        tensor::getMixedSizes(builder, getLoc(), getDest());
+  }
   return success();
 }
 
 Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
                                                         Location loc) {
-  // The subset is the entire destination tensor.
-  return getDest();
+  if (isa<TensorType>(getDest().getType())) {
+    // The subset is the entire destination tensor.
+    return getDest();
+  }
+
+  // Build a bufferization.to_tensor op.
+  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
+  assert(getRestrict() &&
+         "expected that ops with memref dest have 'restrict'");
+  return builder.create<ToTensorOp>(loc, getDest(), getRestrict(),
+                                    getWritable());
 }
 
 bool MaterializeInDestinationOp::isEquivalentSubset(
@@ -598,6 +623,51 @@ OpOperand &MaterializeInDestinationOp::getSourceOperand() {
   return getOperation()->getOpOperand(0) /*source*/;
 }
 
+LogicalResult MaterializeInDestinationOp::verify() {
+  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
+    return emitOpError("'dest' must be a tensor or a memref");
+  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
+    if (getOperation()->getNumResults() != 1)
+      return emitOpError("tensor 'dest' implies exactly one tensor result");
+    if (destType != getResult().getType())
+      return emitOpError("result and 'dest' types must match");
+  }
+  if (isa<BaseMemRefType>(getDest().getType()) &&
+      getOperation()->getNumResults() != 0)
+    return emitOpError("memref 'dest' implies zero results");
+  if (getRestrict() != isa<BaseMemRefType>(getDest().getType()))
+    return emitOpError("'restrict' must be specified if and only if the "
+                       "destination is of memref type");
+  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
+    return emitOpError("'writable' must be specified if and only if the "
+                       "destination is of memref type");
+  return success();
+}
+
+void MaterializeInDestinationOp::build(OpBuilder &builder,
+                                       OperationState &state, Value source,
+                                       Value dest) {
+  assert(isa<TensorType>(dest.getType()) && "expected tensor type");
+  build(builder, state, /*result=*/dest.getType(), source, dest);
+}
+
+bool MaterializeInDestinationOp::isWritable(Value value,
+                                            const AnalysisState &state) {
+  return isa<TensorType>(getDest().getType()) ? true : getWritable();
+}
+
+MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
+  return getDestMutable();
+}
+
+void MaterializeInDestinationOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (isa<BaseMemRefType>(getDest().getType()))
+    effects.emplace_back(MemoryEffects::Write::get(), getDest(),
+                         SideEffects::DefaultResource::get());
+}
+
 //===----------------------------------------------------------------------===//
 // ToTensorOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 77ad13dacaa9838..e37c20dc68c88a6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -149,6 +149,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
           op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
       if (!replacement)
         continue;
+
       if (replacement.getType() != v.getType()) {
         rewriter.setInsertionPointAfterValue(replacement);
         replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index a74a3c2c500406f..e6d80a39650ccf0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -248,8 +248,10 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                LinalgPaddingOptions::CopyBackOp::
                    BufferizationMaterializeInDestination) {
       replacements.push_back(
-          rewriter.create<bufferization::MaterializeInDestinationOp>(
-              loc, std::get<0>(it), std::get<1>(it).get()));
+          rewriter
+              .create<bufferization::MaterializeInDestinationOp>(
+                  loc, std::get<0>(it), std::get<1>(it).get())
+              ->getResult(0));
     } else {
       llvm_unreachable("unsupported copy back op");
     }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
index a2fbb06d179ebda..c3e44c426797f39 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -172,7 +172,7 @@ func.func @materialize_in_destination_aliasing(%t: tensor<?xf32>, %p1: index, %p
   %dest = tensor.extract_slice %t[%p2][5][1] : tensor<?xf32> to tensor<5xf32>
   // CHECK: bufferization.materialize_in_destination
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
-  %r = bufferization.materialize_in_destination %src in %dest : tensor<5xf32>
+  %r = bufferization.materialize_in_destination %src in %dest : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
   return %r : tensor<5xf32>
 }
 
@@ -183,6 +183,6 @@ func.func @materialize_in_destination(%t: tensor<?xf32>, %sz: index) -> tensor<?
   %buffer = tensor.empty(%sz) : tensor<?xf32>
   // CHECK: bufferization.materialize_in_destination
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
-  %r = bufferization.materialize_in_destination %buffer in %buffer : tensor<?xf32>
+  %r = bufferization.materialize_in_destination %buffer in %buffer : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   return %r : tensor<?xf32>
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 41e43047657daff..f3c063826e31a90 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -301,12 +301,25 @@ func.func @regression_multiple_insertion_points(%t1: tensor<?x?xf32>) -> tensor<
 func.func @materialize_in_destination(%t: tensor<5xf32>, %f: f32) -> tensor<5xf32> {
   %0 = tensor.empty() : tensor<5xf32>
   %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
-  %1 = bufferization.materialize_in_destination %filled in %t : tensor<5xf32>
+  %1 = bufferization.materialize_in_destination %filled in %t : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
   return %1 : tensor<5xf32>
 }
 
 // -----
 
+// CHECK-LABEL: func @materialize_in_destination_buffer(
+//  CHECK-SAME:     %[[m:.*]]: memref<5xf32>,
+//  CHECK-NEXT:   linalg.fill {{.*}} outs(%[[m]]
+//  CHECK-NEXT:   return
+func.func @materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32) {
+  %0 = tensor.empty() : tensor<5xf32>
+  %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+  bufferization.materialize_in_destination %filled in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @linalg_copy(
 //  CHECK-SAME:     %[[m:.*]]: memref<5xf32, strided<[?], offset: ?>>,
 //       CHECK:   linalg.fill {{.*}} outs(%[[m]]
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index 3f468750cc28405..c8a9b0b1fefc940 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -218,6 +218,19 @@ func.func @tensor_copy(%arg0: tensor<5xf32>) -> tensor<5xf32> {
   // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
   // CHECK: return %[[r]]
   %dest = bufferization.alloc_tensor() : tensor<5xf32>
-  %0 = bufferization.materialize_in_destination %arg0 in %dest : tensor<5xf32>
+  %0 = bufferization.materialize_in_destination %arg0 in %dest
+      : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
   return %0 : tensor<5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @materialize_in_destination_buffer(
+//  CHECK-SAME:     %[[t:.*]]: tensor<5xf32>, %[[m:.*]]: memref<5xf32>)
+//       CHECK:   %[[b:.*]] = bufferization.to_memref %[[t]] : memref<5xf32, strided<[?], offset: ?>>
+//       CHECK:   memref.copy %[[b]], %[[m]]
+func.func @materialize_in_destination_buffer(%t: tensor<5xf32>, %m: memref<5xf32>) {
+  bufferization.materialize_in_destination %t in restrict writable %m
+      : (tensor<5xf32>, memref<5xf32>) -> ()
+  return
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 3dfd1eb17e8d64f..5020ab9cb7368b1 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -66,10 +66,51 @@ func.func @invalid_writable_on_op() {
 
 // -----
 
-// expected-note @below{{prior use here}}
 func.func @invalid_materialize_in_destination(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
-  // expected-error @below{{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<5xf32>'}}
-  bufferization.materialize_in_destination %arg0 in %arg1 : tensor<?xf32>
+  // expected-error @below{{failed to verify that all of {source, dest} have same shape}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_restrict_missing(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
+  // expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, memref<?xf32>) -> ()
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_result(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
+  // expected-error @below{{memref 'dest' implies zero results}}
+  bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, memref<?xf32>) -> (tensor<?xf32>)
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_result_missing(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // expected-error @below{{tensor 'dest' implies exactly one tensor result}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> ()
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
+  bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // expected-error @below{{'writable' must be specified if and only if the destination is of memref type}}
+  bufferization.materialize_in_destination %arg0 in writable %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_result_shape(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // expected-error @below{{result and 'dest' types must match}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<6xf32>)
 }
 
 // -----
diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
index dc53e535bfe0d57..d4bda0632189d41 100644
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -59,10 +59,12 @@ func.func @test_dealloc_tensor_op(%arg0: tensor<4xi32>) {
 }
 
 // CHECK-LABEL: func @test_materialize_in_destination_op
-func.func @test_materialize_in_destination_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>)
+func.func @test_materialize_in_destination_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: memref<?xf32, 3>)
     -> tensor<?xf32> {
-  // CHECK: bufferization.materialize_in_destination {{.*}} : tensor<?xf32>
-  %1 = bufferization.materialize_in_destination %arg0 in %arg1 : tensor<?xf32>
+  // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %1 = bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, memref<?xf32, 3>) -> ()
+  bufferization.materialize_in_destination %arg0 in restrict writable %arg2 : (tensor<?xf32>, memref<?xf32, 3>) -> ()
   return %1 : tensor<?xf32>
 }
 


