[Mlir-commits] [mlir] 0fcaca2 - [mlir][bufferization] `MaterializeInDestinationOp`: Support memref destinations (#68074)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Fri Oct  6 02:57:15 PDT 2023
    
    
  
Author: Matthias Springer
Date: 2023-10-06T11:57:10+02:00
New Revision: 0fcaca2feaa973afa9275c0cb931775f5e3bde4c
URL: https://github.com/llvm/llvm-project/commit/0fcaca2feaa973afa9275c0cb931775f5e3bde4c
DIFF: https://github.com/llvm/llvm-project/commit/0fcaca2feaa973afa9275c0cb931775f5e3bde4c.diff
LOG: [mlir][bufferization] `MaterializeInDestinationOp`: Support memref destinations (#68074)
Extend `bufferization.materialize_in_destination` to support memref
destinations. This op can now be used to indicate that a tensor
computation should materialize in a given buffer (that may have been
allocated by another component/runtime). The op still participates in
"empty tensor elimination".
Example:
```mlir
func.func @test(%out: memref<10xf32>) {
  %t = tensor.empty() : tensor<10xf32>
  %c = linalg.generic ... outs(%t: tensor<10xf32>) -> tensor<10xf32>
  bufferization.materialize_in_destination %c in restrict writable %out : (tensor<10xf32>, memref<10xf32>) -> ()
  return
}
```
After "empty tensor elimination", the above IR can bufferize without an
allocation:
```mlir
func.func @test(%out: memref<10xf32>) {
  linalg.generic ... outs(%out: memref<10xf32>)
  return
}
```
This change also clarifies the meaning of the `restrict` unit attribute
on `bufferization.to_tensor` ops.
Added: 
    
Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
    mlir/test/Dialect/Bufferization/invalid.mlir
    mlir/test/Dialect/Bufferization/ops.mlir
Removed: 
    
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index db93a51775ffcd7..09ce2981d382680 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -216,33 +216,58 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
 
 def Bufferization_MaterializeInDestinationOp
     : Bufferization_Op<"materialize_in_destination",
-        [BufferizableOpInterface, SameOperandsAndResultType,
-         DestinationStyleOpInterface,
+        [AllShapesMatch<["source", "dest"]>,
+         AllElementTypesMatch<["source", "dest"]>,
+         BufferizableOpInterface, DestinationStyleOpInterface,
          DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
          DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
             ["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
-             "buildSubsetExtraction", "isEquivalentSubset"]>]> {
+             "buildSubsetExtraction", "isEquivalentSubset"]>,
+         DeclareOpInterfaceMethods<MemoryEffectsOpInterface, ["getEffects"]>]> {
   let summary = "copy a tensor";
 
   let description = [{
     This op indicates that the data of the `source` tensor should materialize
-    in the future buffer of the `dest` tensors. Both tensors must have the same
-    shape and element type at runtime.
+    in `dest`, which can be a tensor or a memref. In case of a tensor, `source`
+    should materialize in the future buffer of `dest` and a the updated
+    destination tensor is returned. In case of a memref, `source` should
+    materialize in `dest`, which is already a buffer. The op has no results in
+    that case.
+
+    `source`, `dest` and `result` (if present) must have the same shape and
+    element type. If the op has a result, the types of `result` and `dest` must
+    match exactly (e.g., including any tensor encodings).
 
     By default, this op bufferizes to a memcpy from the future buffer of the
-    `source` tensor to the future buffer of the `dest` tensor. However,
-    transformations such as "empty tensor elimination" may rewrite IR such that
-    a computation is performed directly in the future buffer of the `dest`
-    tensor and no memcpy is needed.
-
-    Note: "tensor.insert_slice" could be used for the same purpose, but since
-    tensor dialect ops only indicate *what* should be computed but not *where*,
-    it could fold away, causing the computation to materialize in a different
-    buffer.
+    `source` tensor to the future buffer of the `dest` tensor or to the `dest`
+    buffer. However, transformations such as "empty tensor elimination" may
+    rewrite IR such that a computation is performed directly in `dest` and no
+    memcpy is needed.
+
+    If `dest` is a buffer, the `restrict` and `writable` attributes must be
+    specified. These attributes have the same meaning as the respective
+    attributes of `bufferization.to_tensor`. `writable` indicates that the
+    `dest` buffer is considered writable. It does not make sense to materialize
+    a computation in a read-only buffer, so `writable` is required. `restrict`
+    indicates that this op is the only way for the tensor IR to access `dest`
+    (or an alias thereof). E.g., there must be no other `to_tensor` ops with
+    `dest` or with an alias of `dest`. Such IR is not supported by
+    One-Shot Bufferize.
+
+    Note: `restrict` and `writable` could be removed from this op because they
+    must always be set for memref destinations. This op has these attributes to
+    make clear the requirements on the `dest` operand in the op assembly format.
+    Moreover, these requirements may be relaxed at some point in the future.
+
+    Note: If `dest` is a tensor, `tensor.insert_slice` could be used for the
+    same purpose, but since tensor dialect ops only indicate *what* should be
+    computed but not *where*, it could fold away, causing the computation to
+    materialize in a different buffer.
   }];
 
-  let arguments = (ins AnyTensor:$source, AnyTensor:$dest);
-  let results = (outs AnyTensor:$result);
+  let arguments = (ins AnyTensor:$source, AnyShaped:$dest,
+                       UnitAttr:$restrict, UnitAttr:$writable);
+  let results = (outs Optional<AnyTensor>:$result);
 
   let extraClassDeclaration = [{
     LogicalResult bufferize(RewriterBase &rewriter,
@@ -264,10 +289,23 @@ def Bufferization_MaterializeInDestinationOp
       return ::llvm::cast<RankedTensorType>(getResult().getType());
     }
 
-    MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
+    MutableOperandRange getDpsInitsMutable();
+
+    bool isWritable(Value value, const AnalysisState &state);
   }];
 
-  let assemblyFormat = "$source `in` $dest attr-dict `:` type($source)";
+  let builders = [
+    // Builder that materializes a source tensor in a tensor destination.
+    // Asserts that `dest` has tensor type. Infers the result type of this op
+    // from the destination tensor.
+    OpBuilder<(ins "Value":$source, "Value":$dest)>
+  ];
+
+  let assemblyFormat = [{
+    $source `in` (`restrict` $restrict^)? (`writable` $writable^)? $dest
+        attr-dict `:` functional-type(operands, results)
+  }];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -361,10 +399,15 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     thereof) will bufferize out-of-place to prevent emitting any writes to
     `memref` during bufferization.
 
-    If the given memref does not alias with any other memref passed to another
-    `to_tensor` op, the `restrict` unit attribute can be set. Only such
-    operations are supported by One-Shot Bufferize. (Otherwise, potential memref
-    aliasing relationships would have to be captured in One-Shot Bufferize.)
+    The `restrict` unit attribute (similar to the C `restrict` keyword)
+    indicates that the produced tensor result is the only way for the tensor
+    IR to gain access to the `memref` operand (or an alias thereof). E.g.,
+    there must be no other `to_tensor` op with the same or with an aliasing
+    `memref` operand.
+
+    Note: Only `to_tensor` ops with the `restrict` unit attribute are supported
+    by One-Shot Bufferize. Other IR is rejected. (To support `to_tensor`
+    without `restrict`, One-Shot Bufferize would have to analyze memref IR.)
 
     Example:
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 01cbacc96fd42d2..1c33f444d15850c 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -542,25 +542,40 @@ bool MaterializeInDestinationOp::bufferizesToMemoryRead(
 
 bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
     OpOperand &opOperand, const AnalysisState &state) {
-  return &opOperand == &getDestMutable();
+  if (&opOperand == &getDestMutable()) {
+    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
+    return true;
+  }
+  return false;
 }
 
 AliasingValueList
 MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                               const AnalysisState &state) {
-  if (&opOperand == &getDestMutable())
+  if (&opOperand == &getDestMutable()) {
+    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
     return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
+  }
   return {};
 }
 
 LogicalResult
 MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
-  FailureOr<Value> buffer = getBuffer(rewriter, getDest(), options);
-  if (failed(buffer))
-    return failure();
-  rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), *buffer);
-  replaceOpWithBufferizedValues(rewriter, getOperation(), *buffer);
+  bool tensorDest = isa<TensorType>(getDest().getType());
+  Value buffer;
+  if (tensorDest) {
+    FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
+    if (failed(maybeBuffer))
+      return failure();
+    buffer = *maybeBuffer;
+  } else {
+    assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
+    buffer = getDest();
+  }
+  rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), buffer);
+  replaceOpWithBufferizedValues(rewriter, getOperation(),
+                                tensorDest ? ValueRange(buffer) : ValueRange());
   return success();
 }
 
@@ -573,15 +588,29 @@ bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
 
 LogicalResult MaterializeInDestinationOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
-  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
+  if (getOperation()->getNumResults() == 1) {
+    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
+    reifiedReturnShapes.resize(1,
+                               SmallVector<OpFoldResult>(getType().getRank()));
+    reifiedReturnShapes[0] =
+        tensor::getMixedSizes(builder, getLoc(), getDest());
+  }
   return success();
 }
 
 Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
                                                         Location loc) {
-  // The subset is the entire destination tensor.
-  return getDest();
+  if (isa<TensorType>(getDest().getType())) {
+    // The subset is the entire destination tensor.
+    return getDest();
+  }
+
+  // Build a bufferization.to_tensor op.
+  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
+  assert(getRestrict() &&
+         "expected that ops with memrefs dest have 'restrict'");
+  return builder.create<ToTensorOp>(loc, getDest(), getRestrict(),
+                                    getWritable());
 }
 
 bool MaterializeInDestinationOp::isEquivalentSubset(
@@ -598,6 +627,51 @@ OpOperand &MaterializeInDestinationOp::getSourceOperand() {
   return getOperation()->getOpOperand(0) /*source*/;
 }
 
+LogicalResult MaterializeInDestinationOp::verify() {
+  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
+    return emitOpError("'dest' must be a tensor or a memref");
+  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
+    if (getOperation()->getNumResults() != 1)
+      return emitOpError("tensor 'dest' implies exactly one tensor result");
+    if (destType != getResult().getType())
+      return emitOpError("result and 'dest' types must match");
+  }
+  if (isa<BaseMemRefType>(getDest().getType()) &&
+      getOperation()->getNumResults() != 0)
+    return emitOpError("memref 'dest' implies zero results");
+  if (getRestrict() != isa<BaseMemRefType>(getDest().getType()))
+    return emitOpError("'restrict' must be specified if and only if the "
+                       "destination is of memref type");
+  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
+    return emitOpError("'writable' must be specified if and only if the "
+                       "destination is of memref type");
+  return success();
+}
+
+void MaterializeInDestinationOp::build(OpBuilder &builder,
+                                       OperationState &state, Value source,
+                                       Value dest) {
+  assert(isa<TensorType>(dest.getType()) && "expected tensor type");
+  build(builder, state, /*result=*/dest.getType(), source, dest);
+}
+
+bool MaterializeInDestinationOp::isWritable(Value value,
+                                            const AnalysisState &state) {
+  return isa<TensorType>(getDest().getType()) ? true : getWritable();
+}
+
+MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
+  return getDestMutable();
+}
+
+void MaterializeInDestinationOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (isa<BaseMemRefType>(getDest().getType()))
+    effects.emplace_back(MemoryEffects::Write::get(), getDest(),
+                         SideEffects::DefaultResource::get());
+}
+
 //===----------------------------------------------------------------------===//
 // ToTensorOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index a74a3c2c500406f..e6d80a39650ccf0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -248,8 +248,10 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                LinalgPaddingOptions::CopyBackOp::
                    BufferizationMaterializeInDestination) {
       replacements.push_back(
-          rewriter.create<bufferization::MaterializeInDestinationOp>(
-              loc, std::get<0>(it), std::get<1>(it).get()));
+          rewriter
+              .create<bufferization::MaterializeInDestinationOp>(
+                  loc, std::get<0>(it), std::get<1>(it).get())
+              ->getResult(0));
     } else {
       llvm_unreachable("unsupported copy back op");
     }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
index a2fbb06d179ebda..c3e44c426797f39 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -172,7 +172,7 @@ func.func @materialize_in_destination_aliasing(%t: tensor<?xf32>, %p1: index, %p
   %dest = tensor.extract_slice %t[%p2][5][1] : tensor<?xf32> to tensor<5xf32>
   // CHECK: bufferization.materialize_in_destination
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
-  %r = bufferization.materialize_in_destination %src in %dest : tensor<5xf32>
+  %r = bufferization.materialize_in_destination %src in %dest : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
   return %r : tensor<5xf32>
 }
 
@@ -183,6 +183,6 @@ func.func @materialize_in_destination(%t: tensor<?xf32>, %sz: index) -> tensor<?
   %buffer = tensor.empty(%sz) : tensor<?xf32>
   // CHECK: bufferization.materialize_in_destination
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
-  %r = bufferization.materialize_in_destination %buffer in %buffer : tensor<?xf32>
+  %r = bufferization.materialize_in_destination %buffer in %buffer : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   return %r : tensor<?xf32>
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index b68682a459ed2c2..99b974b9ef3c67e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -301,12 +301,25 @@ func.func @regression_multiple_insertion_points(%t1: tensor<?x?xf32>) -> tensor<
 func.func @materialize_in_destination(%t: tensor<5xf32>, %f: f32) -> tensor<5xf32> {
   %0 = tensor.empty() : tensor<5xf32>
   %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
-  %1 = bufferization.materialize_in_destination %filled in %t : tensor<5xf32>
+  %1 = bufferization.materialize_in_destination %filled in %t : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
   return %1 : tensor<5xf32>
 }
 
 // -----
 
+// CHECK-LABEL: func @materialize_in_destination_buffer(
+//  CHECK-SAME:     %[[m:.*]]: memref<5xf32>,
+//  CHECK-NEXT:   linalg.fill {{.*}} outs(%[[m]]
+//  CHECK-NEXT:   return
+func.func @materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32) {
+  %0 = tensor.empty() : tensor<5xf32>
+  %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+  bufferization.materialize_in_destination %filled in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @linalg_copy(
 //  CHECK-SAME:     %[[m:.*]]: memref<5xf32, strided<[?], offset: ?>>,
 //       CHECK:   linalg.fill {{.*}} outs(%[[m]]
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index 3f468750cc28405..272423de5564b09 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -218,6 +218,20 @@ func.func @tensor_copy(%arg0: tensor<5xf32>) -> tensor<5xf32> {
   // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
   // CHECK: return %[[r]]
   %dest = bufferization.alloc_tensor() : tensor<5xf32>
-  %0 = bufferization.materialize_in_destination %arg0 in %dest : tensor<5xf32>
+  %0 = bufferization.materialize_in_destination %arg0 in %dest
+      : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
   return %0 : tensor<5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @materialize_in_destination_buffer(
+//  CHECK-SAME:     %[[t:.*]]: tensor<5xf32>, %[[m:.*]]: memref<5xf32>)
+//       CHECK:   %[[b:.*]] = bufferization.to_memref %[[t]] : memref<5xf32, strided<[?], offset: ?>>
+//       CHECK:   memref.copy %[[b]], %[[m]]
+func.func @materialize_in_destination_buffer(%t: tensor<5xf32>, %m: memref<5xf32>) {
+  bufferization.materialize_in_destination %t in restrict writable %m
+      : (tensor<5xf32>, memref<5xf32>) -> ()
+  return
+}
+
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 8004ec632453e8c..ce56f89c1f1bbe6 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -66,10 +66,58 @@ func.func @invalid_writable_on_op() {
 
 // -----
 
-// expected-note @below{{prior use here}}
 func.func @invalid_materialize_in_destination(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
-  // expected-error @below{{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<5xf32>'}}
-  bufferization.materialize_in_destination %arg0 in %arg1 : tensor<?xf32>
+  // expected-error @below{{failed to verify that all of {source, dest} have same shape}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_dest_type(%arg0: tensor<5xf32>, %arg1: vector<5xf32>) {
+  // expected-error @below{{'dest' must be a tensor or a memref}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<5xf32>, vector<5xf32>) -> ()
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_restrict_missing(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
+  // expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, memref<?xf32>) -> ()
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_result(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
+  // expected-error @below{{memref 'dest' implies zero results}}
+  bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, memref<?xf32>) -> (tensor<?xf32>)
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_result_missing(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // expected-error @below{{tensor 'dest' implies exactly one tensor result}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> ()
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
+  bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // expected-error @below{{'writable' must be specified if and only if the destination is of memref type}}
+  bufferization.materialize_in_destination %arg0 in writable %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_result_shape(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // expected-error @below{{result and 'dest' types must match}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<6xf32>)
 }
 
 // -----
diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
index dc53e535bfe0d57..d4bda0632189d41 100644
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -59,10 +59,12 @@ func.func @test_dealloc_tensor_op(%arg0: tensor<4xi32>) {
 }
 
 // CHECK-LABEL: func @test_materialize_in_destination_op
-func.func @test_materialize_in_destination_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>)
+func.func @test_materialize_in_destination_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: memref<?xf32, 3>)
     -> tensor<?xf32> {
-  // CHECK: bufferization.materialize_in_destination {{.*}} : tensor<?xf32>
-  %1 = bufferization.materialize_in_destination %arg0 in %arg1 : tensor<?xf32>
+  // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %1 = bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, memref<?xf32, 3>) -> ()
+  bufferization.materialize_in_destination %arg0 in restrict writable %arg2 : (tensor<?xf32>, memref<?xf32, 3>) -> ()
   return %1 : tensor<?xf32>
 }
 
        
    
    
More information about the Mlir-commits
mailing list