[Mlir-commits] [mlir] 56d68e8 - [mlir][bufferization] Add optional `copy` operand to AllocTensorOp

Matthias Springer <llvmlistbot at llvm.org>
Thu Jun 9 12:41:58 PDT 2022


Author: Matthias Springer
Date: 2022-06-09T21:37:15+02:00
New Revision: 56d68e8d7a17c32d2fd0d0894fc0776df0d85673

URL: https://github.com/llvm/llvm-project/commit/56d68e8d7a17c32d2fd0d0894fc0776df0d85673
DIFF: https://github.com/llvm/llvm-project/commit/56d68e8d7a17c32d2fd0d0894fc0776df0d85673.diff

LOG: [mlir][bufferization] Add optional `copy` operand to AllocTensorOp

If `copy` is specified, the newly allocated buffer is initialized with the given contents. Also add an optional `escape` attribute to indicate whether the buffer of the tensor may be returned from the parent block (i.e., "escape") after bufferization.
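
For illustration, the new forms look as follows (a sketch mirroring the new test cases in ops.mlir below; `%t` is a tensor value and `%sz` an index):

```mlir
// Uninitialized allocation with one dynamic size; reading the result
// yields an undefined value.
%0 = bufferization.alloc_tensor(%sz) : tensor<?x5xf32>

// Allocation initialized with the contents of %t; the dynamic sizes
// are taken from %t, so none may be passed explicitly.
%1 = bufferization.alloc_tensor() copy(%t) : tensor<?x5xf32>

// The buffer may escape the parent block, so bufferization inserts
// no dealloc for it.
%2 = bufferization.alloc_tensor() copy(%t) {escape = true} : tensor<?x5xf32>
```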

This change is in preparation for connecting One-Shot Bufferize to the sparse compiler.

Differential Revision: https://reviews.llvm.org/D126570

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/python/mlir/dialects/_bufferization_ops_ext.py
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
    mlir/test/Dialect/Bufferization/canonicalize.mlir
    mlir/test/Dialect/Bufferization/invalid.mlir
    mlir/test/Dialect/Bufferization/ops.mlir
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index f0904932b539a..a0509767cfed8 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -24,15 +24,24 @@ class Bufferization_Op<string mnemonic, list<Trait> traits = []>
 //===----------------------------------------------------------------------===//
 
 def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
-    [BufferizableOpInterface,
+    [AttrSizedOperandSegments, BufferizableOpInterface,
      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
   let summary = "buffer allocation in tensor land";
 
   let description = [{
     `bufferization.alloc_tensor` materializes an uninitialized tensor with a
     given shape (dynamic or static). It always bufferizes to a new buffer
-    allocation of the given shape. Reading from the result of an `alloc_tensor`
-    op yields an undefined value.
+    allocation of the given shape. The optional `copy` operand specifies the
+    contents of the tensor. If no `copy` operand is specified, reading from the
+    result of an `alloc_tensor` op yields an undefined value.
+
+    If `copy` is specified, no dynamic sizes should be passed, since they are
+    the same as the dynamic sizes of the `copy` operand.
+
+    The optional `escape` attribute indicates whether the buffer escapes the
+    parent block or not. In the latter case, the buffer is deallocated at the
+    end of the block (during bufferization). In the former case, the buffer
+    is not deallocated and must be freed through some other mechanism.
 
     `alloc_tensor` is a helper op for bufferization. The operation is provided
     as an anchor that marks the beginning of a new tensor SSA use-def chain. It
@@ -55,19 +64,25 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     ```
   }];
 
-  let arguments = (ins Variadic<Index>:$dynamicSizes);
+  let arguments = (ins Variadic<Index>:$dynamicSizes,
+                       Optional<AnyTensor>:$copy,
+                       OptionalAttr<BoolAttr>:$escape);
 
   let results = (outs AnyTensor:$result);
 
-  let assemblyFormat = "`(`$dynamicSizes`)` attr-dict `:` type($result)";
-
   let extraClassDeclaration = [{
     LogicalResult bufferize(RewriterBase &rewriter, BufferizationState &state);
 
-    bool isMemoryWrite(OpResult opResult, const AnalysisState &state) const {
-      // AllocTensorOps allocate but do not write.
-      return false;
-    }
+    bool isMemoryWrite(OpResult opResult, const AnalysisState &state);
+
+    bool bufferizesToMemoryRead(OpOperand &opOperand,
+                                const AnalysisState &state);
+
+    bool bufferizesToMemoryWrite(OpOperand &opOperand,
+                                 const AnalysisState &state);
+
+    SmallVector<OpResult> getAliasingOpResult(
+        OpOperand &opOperand, const AnalysisState &state);
 
     RankedTensorType getType() {
       return getResult().getType().cast<RankedTensorType>();
@@ -82,6 +97,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     // the tensor at dimension `idx`. Asserts that the shape is
     // dynamic at that `idx`.
     unsigned getIndexOfDynamicSize(unsigned idx) {
+      assert(!copy() && "no dim sizes specified when copying a tensor");
       assert(isDynamicDim(idx) && "expected dynamic size");
       ArrayRef<int64_t> shape = getType().getShape();
       return std::count_if(
@@ -91,9 +107,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
 
     // Return the Value of the dynamic size of the tensor at dimension
     // `idx`. Asserts that the shape is dynamic at that `idx`.
-    Value getDynamicSize(unsigned idx) {
-      return getOperand(getIndexOfDynamicSize(idx));
-    }
+    Value getDynamicSize(OpBuilder &b, unsigned idx);
 
     // Assert that the size of the result tensor is static at `idx`
     // and return the shape.
@@ -103,7 +117,21 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     }
   }];
 
+  let builders = [
+    // Build an op without `copy` operand and `escape` attribute.
+    OpBuilder<(ins "RankedTensorType":$type, "ValueRange":$dynamicSizes)>,
+
+    // Build an op without `escape` attribute.
+    OpBuilder<(ins "RankedTensorType":$type, "ValueRange":$dynamicSizes,
+                   "Value":$copy)>,
+
+    // Build an op with `copy` and `escape` attribute.
+    OpBuilder<(ins "RankedTensorType":$type, "ValueRange":$dynamicSizes,
+                   "Value":$copy, "bool":$escape)>,
+  ];
+
   let hasCanonicalizer = 1;
+  let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }
 

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index d9eb0bb2118db..0025215882db1 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -139,21 +139,83 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
   if (getOperation()->getUses().empty())
     return success();
 
-  FailureOr<Value> alloc = state.createAlloc(rewriter, getLoc(), getResult());
+  Optional<bool> dealloc = llvm::None;
+  if (escape().hasValue())
+    dealloc = !*escape();
+  FailureOr<Value> alloc =
+      state.createAlloc(rewriter, getLoc(), getResult(), dealloc);
   if (failed(alloc))
     return failure();
+  if (copy()) {
+    FailureOr<Value> copyValueBuffer = state.getBuffer(
+        rewriter, getOperation()->getOpOperand(getNumOperands() - 1));
+    if (failed(copyValueBuffer))
+      return failure();
+    if (failed(state.getOptions().createMemCpy(rewriter, getLoc(),
+                                               *copyValueBuffer, *alloc)))
+      return failure();
+  }
   replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
   return success();
 }
 
+bool AllocTensorOp::isMemoryWrite(OpResult opResult,
+                                  const AnalysisState &state) {
+  // AllocTensorOps do not write unless they have a `copy` value.
+  return static_cast<bool>(copy());
+}
+
+bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
+                                           const AnalysisState &state) {
+  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
+         "expected copy operand");
+  return true;
+}
+
+bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
+                                            const AnalysisState &state) {
+  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
+         "expected copy operand");
+  return false;
+}
+
+SmallVector<OpResult>
+AllocTensorOp::getAliasingOpResult(OpOperand &opOperand,
+                                   const AnalysisState &state) {
+  // This is a new allocation. It does not alias with any other buffer.
+  return {};
+}
+
 LogicalResult AllocTensorOp::verify() {
-  if (getType().getNumDynamicDims() !=
-      static_cast<int64_t>(dynamicSizes().size()))
+  if (copy() && !dynamicSizes().empty())
+    return emitError("dynamic sizes not needed when copying a tensor");
+  if (!copy() && getType().getNumDynamicDims() !=
+                     static_cast<int64_t>(dynamicSizes().size()))
     return emitError("expected ")
            << getType().getNumDynamicDims() << " dynamic sizes";
+  if (copy() && copy().getType() != getType())
+    return emitError("expected that `copy` and return type match");
   return success();
 }
 
+void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
+                          RankedTensorType type, ValueRange dynamicSizes) {
+  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
+        /*escape=*/BoolAttr());
+}
+
+void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
+                          RankedTensorType type, ValueRange dynamicSizes,
+                          Value copy) {
+  build(builder, result, type, dynamicSizes, copy, /*escape=*/BoolAttr());
+}
+
+void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
+                          RankedTensorType type, ValueRange dynamicSizes,
+                          Value copy, bool escape) {
+  build(builder, result, type, dynamicSizes, copy, builder.getBoolAttr(escape));
+}
+
 namespace {
 /// Change the type of the result of a `bufferization.alloc_tensor` by making
 /// the result type statically sized along dimension that in the original
@@ -171,6 +233,8 @@ struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
 
   LogicalResult matchAndRewrite(AllocTensorOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.copy())
+      return failure();
     SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
     SmallVector<Value> newDynamicSizes;
     unsigned int dynValCounter = 0;
@@ -189,8 +253,9 @@ struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
         newShape, op.getType().getElementType(), op.getType().getEncoding());
     if (newType == op.getType())
       return failure();
-    auto newOp =
-        rewriter.create<AllocTensorOp>(op.getLoc(), newType, newDynamicSizes);
+    auto newOp = rewriter.create<AllocTensorOp>(
+        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value(),
+        /*escape=*/op.escapeAttr());
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
     return success();
   }
@@ -207,8 +272,8 @@ struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
       return failure();
     if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
       return failure();
-    rewriter.replaceOp(dimOp,
-                       allocTensorOp.getDynamicSize(*maybeConstantIndex));
+    rewriter.replaceOp(
+        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
     return success();
   }
 };
@@ -224,7 +289,7 @@ LogicalResult AllocTensorOp::reifyResultShapes(
   auto shapes = llvm::to_vector<4>(llvm::map_range(
       llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
         if (isDynamicDim(dim))
-          return getDynamicSize(dim);
+          return getDynamicSize(builder, dim);
         return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                       getStaticSize(dim));
       }));
@@ -232,6 +297,59 @@ LogicalResult AllocTensorOp::reifyResultShapes(
   return success();
 }
 
+ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
+  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
+      parser.parseRParen())
+    return failure();
+  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
+  OpAsmParser::UnresolvedOperand copyOperand;
+  if (copyKeyword.succeeded())
+    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
+        parser.parseRParen())
+      return failure();
+  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
+    return failure();
+
+  TensorType type;
+  if (parser.parseCustomTypeWithFallback(type))
+    return failure();
+  result.addTypes(type);
+
+  Type indexType = parser.getBuilder().getIndexType();
+  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
+    return failure();
+  if (copyKeyword.succeeded())
+    if (parser.resolveOperand(copyOperand, type, result.operands))
+      return failure();
+  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
+                      parser.getBuilder().getI32VectorAttr(
+                          {static_cast<int32_t>(dynamicSizesOperands.size()),
+                           static_cast<int32_t>(copyKeyword.succeeded())}));
+  return success();
+}
+
+void AllocTensorOp::print(OpAsmPrinter &p) {
+  p << "(" << dynamicSizes() << ")";
+  if (copy())
+    p << " copy(" << copy() << ")";
+  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
+                              AllocTensorOp::getOperandSegmentSizeAttr()});
+  p << " : ";
+  auto type = result().getType();
+  if (auto validType = type.dyn_cast<::mlir::TensorType>())
+    p.printStrippedAttrOrType(validType);
+  else
+    p << type;
+}
+
+Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
+  assert(isDynamicDim(idx) && "expected dynamic dim");
+  if (copy())
+    return b.create<tensor::DimOp>(getLoc(), copy(), idx);
+  return getOperand(getIndexOfDynamicSize(idx));
+}
+
 //===----------------------------------------------------------------------===//
 // CloneOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/python/mlir/dialects/_bufferization_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_ops_ext.py
index c720844af2ebd..23f78fc80aec3 100644
--- a/mlir/python/mlir/dialects/_bufferization_ops_ext.py
+++ b/mlir/python/mlir/dialects/_bufferization_ops_ext.py
@@ -18,15 +18,20 @@ class AllocTensorOp:
   def __init__(self,
                tensor_type: Type,
                dynamic_sizes: Sequence[Value],
+               copy: Value,
+               escape: BoolAttr,
                *,
                loc=None,
                ip=None):
     """Constructs an `alloc_tensor` with static and/or dynamic sizes."""
     context = get_default_loc_context(loc)
+    attributes = {}
+    if escape:
+      attributes["escape"] = escape
     op = self.build_generic(
         results=[tensor_type],
-        operands=dynamic_sizes,
-        attributes={},
+        operands=[dynamic_sizes, copy],
+        attributes=attributes,
         loc=loc,
         ip=ip)
     OpView.__init__(self, op)
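
For users of the Python bindings, the `AllocTensorOp` constructor now takes the optional `copy` operand and `escape` attribute positionally; passing `None` elides them. A minimal usage sketch (the surrounding setup is illustrative and not part of this patch):

```python
from mlir import ir
from mlir.dialects import bufferization

with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    with ir.InsertionPoint(module.body):
        f32 = ir.F32Type.get()
        tensor_type = ir.RankedTensorType.get([5], f32)
        # Neither a `copy` operand nor an `escape` attribute is given.
        alloc = bufferization.AllocTensorOp(tensor_type, [], None, None)
```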

diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index e2cda814a8d15..26faa04696eae 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -119,3 +119,22 @@ func.func @select_different_tensors(%t: tensor<?xf32>, %sz: index, %c: i1) -> te
   %1 = arith.select %c, %0, %t : tensor<?xf32>
   return %1 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @alloc_tensor_with_copy(
+//  CHECK-SAME:     %[[t:.*]]: tensor<5xf32>)
+// TODO: Add a test case with dynamic dim size. This is not possible at the
+// moment because this would create a tensor op during bufferization. That is
+// currently forbidden.
+func.func @alloc_tensor_with_copy(%t: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: %[[m:.*]] = bufferization.to_memref %[[t]]
+  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+  // CHECK: memref.copy %[[m]], %[[alloc]]
+  %0 = bufferization.alloc_tensor() copy(%t) : tensor<5xf32>
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
+  // CHECK: memref.dealloc %[[alloc]]
+  // CHECK: return %[[r]]
+  return %0 : tensor<5xf32>
+}
+

diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index c41cebd77a9ae..f76f040937aaf 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -224,7 +224,7 @@ func.func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
   return %1 : memref<?x?x16x32xi8>
 }
 // CHECK:   %[[M:.+]] = bufferization.to_memref %[[ARG0]] : memref<4x6x16x32xi8>
-// CHECK:   %[[M1:.+]] = memref.cast %[[M]] 
+// CHECK:   %[[M1:.+]] = memref.cast %[[M]]
 // CHECK-SAME: memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
 // CHECK:   return %[[M1]] : memref<?x?x16x32xi8>
 

diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 9e732b9bc6e48..64d873abb9a1d 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -1,8 +1,33 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
-func.func @alloc_tensor_err(%arg0 : index)
+func.func @alloc_tensor_missing_dims(%arg0: index)
 {
   // expected-error @+1 {{expected 2 dynamic sizes}}
-  %1 = bufferization.alloc_tensor(%arg0) : tensor<4x?x?x5xf32>
+  %0 = bufferization.alloc_tensor(%arg0) : tensor<4x?x?x5xf32>
+  return
+}
+
+// -----
+
+// expected-note @+1 {{prior use here}}
+func.func @alloc_tensor_type_mismatch(%t: tensor<?xf32>) {
+  // expected-error @+1{{expects different type than prior uses: 'tensor<5xf32>' vs 'tensor<?xf32>'}}
+  %0 = bufferization.alloc_tensor() copy(%t) : tensor<5xf32>
+  return
+}
+
+// -----
+
+func.func @alloc_tensor_copy_and_dims(%t: tensor<?xf32>, %sz: index) {
+  // expected-error @+1{{dynamic sizes not needed when copying a tensor}}
+  %0 = bufferization.alloc_tensor(%sz) copy(%t) : tensor<?xf32>
+  return
+}
+
+// -----
+
+func.func @alloc_tensor_invalid_escape_attr(%sz: index) {
+  // expected-error @+1{{op attribute 'escape' failed to satisfy constraint: bool attribute}}
+  %0 = bufferization.alloc_tensor(%sz) {escape = 5} : tensor<?xf32>
   return
 }

diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
index 23ec897df9d45..393c9ae8ddb63 100644
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -22,3 +22,22 @@ func.func @test_to_tensor(%buf : memref<2xf32>) -> tensor<2xf32> {
   %tensor = bufferization.to_tensor %buf : memref<2xf32>
   return %tensor : tensor<2xf32>
 }
+
+// CHECK-LABEL: func @test_alloc_tensor_op
+func.func @test_alloc_tensor_op(%t: tensor<?x5xf32>, %sz: index)
+  -> tensor<?x5xf32>
+{
+  // CHECK: bufferization.alloc_tensor(%{{.*}}) : tensor<?x5xf32>
+  %0 = bufferization.alloc_tensor(%sz) : tensor<?x5xf32>
+  // CHECK: bufferization.alloc_tensor() copy(%{{.*}}) : tensor<?x5xf32>
+  %1 = bufferization.alloc_tensor() copy(%t) : tensor<?x5xf32>
+  // CHECK: bufferization.alloc_tensor() : tensor<5x6xf32>
+  %2 = bufferization.alloc_tensor() : tensor<5x6xf32>
+  // CHECK: bufferization.alloc_tensor(%{{.*}}, %{{.*}}) : tensor<?x?xf32>
+  %3 = bufferization.alloc_tensor(%sz, %sz) : tensor<?x?xf32>
+  // CHECK: bufferization.alloc_tensor() copy(%{{.*}}) {escape = true} : tensor<?x5xf32>
+  %4 = bufferization.alloc_tensor() copy(%t) {escape = true} : tensor<?x5xf32>
+  // CHECK: bufferization.alloc_tensor() copy(%{{.*}}) {escape = false} : tensor<?x5xf32>
+  %5 = bufferization.alloc_tensor() copy(%t) {escape = false} : tensor<?x5xf32>
+  return %1 : tensor<?x5xf32>
+}

diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
index 4b8cfeb1bd36d..9bab366fcbe82 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -890,7 +890,7 @@ def emit_tensor_init(self) -> ir.RankedTensorType:
     mlir_type = _mlir_tensor_type(self.dst_dtype, self.dst_dims,
                                   self.dst_format.mlir_tensor_attr())
     index_type = ir.IndexType.get()
-    return bufferization.AllocTensorOp(mlir_type, [])
+    return bufferization.AllocTensorOp(mlir_type, [], None, None)
 
 
 class _Stats:
