[Mlir-commits] [mlir] 7a3a253 - [MLIR][BufferPlacement] Support functions that return Memref typed results

Ehsan Toosi llvmlistbot at llvm.org
Fri May 29 02:04:06 PDT 2020


Author: Ehsan Toosi
Date: 2020-05-29T11:03:22+02:00
New Revision: 7a3a2535854c84b1c8f6b0a2f2677e89b0e1a250

URL: https://github.com/llvm/llvm-project/commit/7a3a2535854c84b1c8f6b0a2f2677e89b0e1a250
DIFF: https://github.com/llvm/llvm-project/commit/7a3a2535854c84b1c8f6b0a2f2677e89b0e1a250.diff

LOG: [MLIR][BufferPlacement] Support functions that return Memref typed results

Buffer placement can now operate on functions that return buffers. These
buffers escape the deallocation phase of buffer placement.

Differential Revision: https://reviews.llvm.org/D80696

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/BufferPlacement.h
    mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
    mlir/lib/Transforms/BufferPlacement.cpp
    mlir/test/Transforms/buffer-placement-preparation.mlir
    mlir/test/Transforms/buffer-placement.mlir
    mlir/test/lib/Transforms/TestBufferPlacement.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h
index 030b87599d06..10949160fcbd 100644
--- a/mlir/include/mlir/Transforms/BufferPlacement.h
+++ b/mlir/include/mlir/Transforms/BufferPlacement.h
@@ -76,11 +76,23 @@ class BufferAssignmentOpConversionPattern
   TypeConverter *converter;
 };
 
-/// Converts the signature of the function using the type converter.
-/// It adds an extra argument for each illegally-typed function
-/// result to the function arguments. `BufferAssignmentTypeConverter`
-/// is a helper `TypeConverter` for this purpose. All the non-shaped types
-/// of the input function will be converted to memref.
+/// A helper type converter class for using inside Buffer Assignment operation
+/// conversion patterns. The default constructor keeps all the types intact
+/// except for the ranked-tensor types, which are converted to memref types.
+class BufferAssignmentTypeConverter : public TypeConverter {
+public:
+  BufferAssignmentTypeConverter();
+
+  /// A helper function to check if `type` has been converted from non-memref
+  /// type to memref.
+  static bool isConvertedMemref(Type type, Type before);
+};
+
+/// Converts the signature of the function using the type converter. It adds an
+/// extra argument for each function result type which is going to be a memref
+/// type after type conversion. The other function result types remain
+/// unchanged. `BufferAssignmentTypeConverter` is a helper `TypeConverter` for
+/// this purpose.
 class FunctionAndBlockSignatureConverter
     : public BufferAssignmentOpConversionPattern<FuncOp> {
 public:
@@ -93,12 +105,14 @@ class FunctionAndBlockSignatureConverter
                   ConversionPatternRewriter &rewriter) const final;
 };
 
-/// Converts the source `ReturnOp` to target `ReturnOp`, removes all
-/// the buffer operands from the operands list, and inserts `CopyOp`s
-/// for all buffer operands instead.
+/// Rewrites the `ReturnOp` to conform with the changed function signature.
+/// Operands that correspond to return values that have been rewritten from
+/// tensor results to memref arguments are dropped. In their place, a
+/// corresponding copy operation from the operand to the new function argument
+/// is inserted.
 template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
           typename CopyOpTy>
-class NoBufferOperandsReturnOpConverter
+class BufferAssignmentReturnOpConverter
     : public BufferAssignmentOpConversionPattern<ReturnOpSourceTy> {
 public:
   using BufferAssignmentOpConversionPattern<
@@ -108,50 +122,41 @@ class NoBufferOperandsReturnOpConverter
   LogicalResult
   matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
+    // Split the operands depending on whether they were converted to memref
+    // (these need a copy) or not.
+    SmallVector<Value, 2> needCopyOperands, newOperands;
+    unsigned operandsSize = operands.size();
+    needCopyOperands.reserve(operandsSize);
+    newOperands.reserve(operandsSize);
+    for (auto operand : llvm::enumerate(operands))
+      if (BufferAssignmentTypeConverter::isConvertedMemref(
+              operand.value().getType(),
+              returnOp.getOperand(operand.index()).getType()))
+        needCopyOperands.push_back(operand.value());
+      else
+        newOperands.push_back(operand.value());
+
     Block &entryBlock = returnOp.getParentRegion()->front();
     unsigned numFuncArgs = entryBlock.getNumArguments();
-    Location loc = returnOp.getLoc();
-
-    // The target `ReturnOp` should not contain any memref operands.
-    SmallVector<Value, 2> newOperands(operands.begin(), operands.end());
-    llvm::erase_if(newOperands, [](Value operand) {
-      return operand.getType().isa<MemRefType>();
-    });
 
     // Find the index of the first destination buffer.
-    unsigned numBufferOperands = operands.size() - newOperands.size();
-    unsigned destArgNum = numFuncArgs - numBufferOperands;
-
+    assert(needCopyOperands.size() <= numFuncArgs &&
+           "The number of operands of return operation is more than the "
+           "number of function arguments.");
+    unsigned destArgNum = numFuncArgs - needCopyOperands.size();
     rewriter.setInsertionPoint(returnOp);
-    // Find the corresponding destination buffer for each memref operand.
-    for (Value operand : operands)
-      if (operand.getType().isa<MemRefType>()) {
-        assert(destArgNum < numFuncArgs &&
-               "The number of operands of return operation is more than the "
-               "number of function argument.");
-
-        // For each memref type operand of the source `ReturnOp`, a new `CopyOp`
-        // is inserted that copies the buffer content from the operand to the
-        // target.
-        rewriter.create<CopyOpTy>(loc, operand,
-                                  entryBlock.getArgument(destArgNum));
-        ++destArgNum;
-      }
+    for (Value operand : needCopyOperands) {
+      // Insert a `CopyOp` for each converted memref-type operand.
+      rewriter.create<CopyOpTy>(returnOp.getLoc(), operand,
+                                entryBlock.getArgument(destArgNum));
+      ++destArgNum;
+    }
 
     // Insert the new target Return operation.
     rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, newOperands);
     return success();
   }
 };
-
-/// A helper type converter class for using inside Buffer Assignment operation
-/// conversion patterns. The default constructor keeps all the types intact
-/// except for the ranked-tensor types which is converted to memref types.
-class BufferAssignmentTypeConverter : public TypeConverter {
-public:
-  BufferAssignmentTypeConverter();
-};
-
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_BUFFERPLACEMENT_H

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
index 9b5855dff0ce..c663eb6017e5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -21,7 +21,7 @@
 
 using namespace mlir;
 using ReturnOpConverter =
-    NoBufferOperandsReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
+    BufferAssignmentReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
                                       linalg::CopyOp>;
 
 namespace {

diff  --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
index cd0641c1ac32..60f49d4e305c 100644
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -389,7 +389,13 @@ struct BufferPlacementPass
 
       // If there is an existing dealloc, move it to the right place.
       Operation *nextOp = positions.getDeallocPosition()->getNextNode();
-      assert(nextOp && "Invalid Dealloc operation position");
+      // If the Dealloc position is at the terminator operation of the block,
+      // then the value escapes deallocation (e.g. it is returned).
+      if (!nextOp) {
+        assert(deallocs.size() == 0 &&
+               "There should be no dealloc for the returned buffer");
+        continue;
+      }
       if (deallocs.size()) {
         (*deallocs.begin())->moveBefore(nextOp);
       } else {
@@ -431,11 +437,6 @@ LogicalResult FunctionAndBlockSignatureConverter::matchAndRewrite(
     return failure();
   }
   auto funcType = funcOp.getType();
-  TypeRange resultTypes = funcType.getResults();
-  if (llvm::any_of(resultTypes,
-                   [](Type type) { return type.isa<MemRefType>(); }))
-    return funcOp.emitError("BufferAssignmentPlacer doesn't currently support "
-                            "functions which return memref typed values");
 
   // Convert function arguments using the provided TypeConverter.
   TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
@@ -443,17 +444,16 @@ LogicalResult FunctionAndBlockSignatureConverter::matchAndRewrite(
     conversion.addInputs(argType.index(),
                          converter->convertType(argType.value()));
 
-  // Adding a function argument for each function result which is going to be a
-  // memref type after type conversion.
+  // If a function result type is not a memref but it would be a memref after
+  // type conversion, a new argument should be appended to the function
+  // arguments list for this result. Otherwise, it remains unchanged as a
+  // function result.
   SmallVector<Type, 2> newResultTypes;
   newResultTypes.reserve(funcOp.getNumResults());
-  for (Type resType : resultTypes) {
+  for (Type resType : funcType.getResults()) {
     Type convertedType = converter->convertType(resType);
-
-    // If the result type is memref after the type conversion, a new argument
-    // should be appended to the function arguments list for this result.
-    // Otherwise, it remains unchanged as a function result.
-    if (convertedType.isa<MemRefType>())
+    if (BufferAssignmentTypeConverter::isConvertedMemref(convertedType,
+                                                         resType))
       conversion.addInputs(convertedType);
     else
       newResultTypes.push_back(convertedType);
@@ -482,6 +482,11 @@ BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() {
   });
 }
 
+/// Checks if `type` has been converted from non-memref type to memref.
+bool BufferAssignmentTypeConverter::isConvertedMemref(Type type, Type before) {
+  return type.isa<MemRefType>() && !before.isa<MemRefType>();
+}
+
 //===----------------------------------------------------------------------===//
 // BufferPlacementPass construction
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Transforms/buffer-placement-preparation.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir
index ef7a2e328da5..8458154e4985 100644
--- a/mlir/test/Transforms/buffer-placement-preparation.mlir
+++ b/mlir/test/Transforms/buffer-placement-preparation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-buffer-placement-preparation -split-input-file -verify-diagnostics %s | FileCheck %s -dump-input-on-failure
+// RUN: mlir-opt -test-buffer-placement-preparation -split-input-file %s | FileCheck %s -dump-input-on-failure
 
 // CHECK-LABEL: func @func_signature_conversion
 func @func_signature_conversion(%arg0: tensor<4x8xf32>) {
@@ -8,12 +8,28 @@ func @func_signature_conversion(%arg0: tensor<4x8xf32>) {
 
 // -----
 
-// expected-error @below {{BufferAssignmentPlacer doesn't currently support functions which return memref typed values}}
-// expected-error @below {{failed to legalize operation 'func'}}
-func @memref_in_function_results(%arg0: tensor<4x8xf32>) -> (tensor<4x8xf32>, memref<5xf32>) {
-  %0 = alloc() : memref<5xf32>
-  return %arg0, %0 : tensor<4x8xf32>, memref<5xf32>
+// Only tensor-typed function results should be converted to memref and moved to
+// the function arguments list. The other (memref) function results remain as
+// function results.
+
+#map0 = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: func @memref_in_function_results
+func @memref_in_function_results(%arg0: tensor<5xf32>, %arg1: memref<10xf32>) -> (tensor<5xf32>, memref<10xf32>, memref<15xf32>) {
+  %0 = alloc() : memref<15xf32>
+  %1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 {
+    ^bb0(%gen1_arg0: f32):
+      %tmp1 = exp %gen1_arg0 : f32
+      linalg.yield %tmp1 : f32
+    }: tensor<5xf32> -> tensor<5xf32>
+  return %1, %arg1, %0 : tensor<5xf32>, memref<10xf32>, memref<15xf32>
 }
+//      CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>, %[[RESULT:.*]]: memref<5xf32>)
+// CHECK-SAME: (memref<10xf32>, memref<15xf32>)
+//      CHECK: %[[FIRST_ALLOC:.*]] = alloc()
+//      CHECK: %[[LINALG_ALLOC:.*]] = alloc()
+//      CHECK: linalg.copy(%[[LINALG_ALLOC]], %[[RESULT]])
+//      CHECK: return %[[ARG1]], %[[FIRST_ALLOC]]
 
 // -----
 

diff  --git a/mlir/test/Transforms/buffer-placement.mlir b/mlir/test/Transforms/buffer-placement.mlir
index afbf34ce43fb..4b401cc841af 100644
--- a/mlir/test/Transforms/buffer-placement.mlir
+++ b/mlir/test/Transforms/buffer-placement.mlir
@@ -457,3 +457,32 @@ func @nested_regions_and_cond_branch(%arg0: i1, %arg1: memref<2xf32>, %arg2: mem
 //      CHECK:  ^[[BB3:.*]]({{.*}}):
 //      CHECK:  linalg.copy
 // CHECK-NEXT:  dealloc %[[GENERIC1_ALLOC]]
+
+// -----
+
+// Test Case: buffer deallocation escaping
+// BufferPlacement Expected Behaviour: It must not dealloc %arg1 and %x
+// since they are operands of the return operation and must escape
+// deallocation. It should dealloc %y after linalg.copy.
+
+#map0 = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: func @memref_in_function_results
+func @memref_in_function_results(%arg0: memref<5xf32>, %arg1: memref<10xf32>, %arg2: memref<5xf32>) -> (memref<10xf32>, memref<15xf32>) {
+  %x = alloc() : memref<15xf32>
+  %y = alloc() : memref<5xf32>
+  linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0, %y {
+  ^bb0(%arg3: f32, %arg4: f32):
+    %2 = exp %arg3 : f32
+    linalg.yield %2 : f32
+  }: memref<5xf32>, memref<5xf32>
+  linalg.copy(%y, %arg2) : memref<5xf32>, memref<5xf32>
+  return %arg1, %x : memref<10xf32>, memref<15xf32>
+}
+// CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>, %[[RESULT:.*]]: memref<5xf32>)
+// CHECK: %[[X:.*]] = alloc()
+// CHECK: %[[Y:.*]] = alloc()
+// CHECK: linalg.copy
+// CHECK: dealloc %[[Y]]
+// CHECK: return %[[ARG1]], %[[X]]
+

diff  --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
index 2d781e64cdfa..6152a9b85435 100644
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -23,7 +23,7 @@ using namespace mlir;
 namespace {
 /// This pass tests the computeAllocPosition helper method and two provided
 /// operation converters, FunctionAndBlockSignatureConverter and
-/// NoBufferOperandsReturnOpConverter. Furthermore, this pass converts linalg
+/// BufferAssignmentReturnOpConverter. Furthermore, this pass converts linalg
 /// operations on tensors to linalg operations on buffers to prepare them for
 /// the BufferPlacement pass that can be applied afterwards.
 struct TestBufferPlacementPreparationPass
@@ -41,16 +41,18 @@ struct TestBufferPlacementPreparationPass
     LogicalResult
     matchAndRewrite(linalg::GenericOp op, ArrayRef<Value> operands,
                     ConversionPatternRewriter &rewriter) const final {
-      auto loc = op.getLoc();
-      SmallVector<Value, 4> args(operands.begin(), operands.end());
+      Location loc = op.getLoc();
+      ResultRange results = op.getOperation()->getResults();
+      SmallVector<Value, 2> newArgs, newResults;
+      newArgs.reserve(operands.size() + results.size());
+      newArgs.append(operands.begin(), operands.end());
+      newResults.reserve(results.size());
 
       // Update all types to memref types.
-      auto results = op.getOperation()->getResults();
       for (auto result : results) {
-        auto type = result.getType().cast<ShapedType>();
-        if (!type)
-          op.emitOpError()
-              << "tensor to buffer conversion expects ranked results";
+        ShapedType type = result.getType().cast<ShapedType>();
+        assert(type && "Generic operations with non-shaped typed results are "
+                       "not currently supported.");
         if (!type.hasStaticShape())
           return rewriter.notifyMatchFailure(
               op, "dynamic shapes not currently supported");
@@ -62,27 +64,39 @@ struct TestBufferPlacementPreparationPass
         rewriter.restoreInsertionPoint(
             bufferAssignment->computeAllocPosition(result));
         auto alloc = rewriter.create<AllocOp>(loc, memrefType);
-        result.replaceAllUsesWith(alloc);
-        args.push_back(alloc);
+        newArgs.push_back(alloc);
+        newResults.push_back(alloc);
       }
 
       // Generate a new linalg operation that works on buffers.
       auto linalgOp = rewriter.create<linalg::GenericOp>(
-          loc, llvm::None, args, rewriter.getI64IntegerAttr(operands.size()),
+          loc, llvm::None, newArgs, rewriter.getI64IntegerAttr(operands.size()),
           rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(),
           op.iterator_types(), op.docAttr(), op.library_callAttr());
 
-      // Move regions from the old operation to the new one.
-      auto &region = linalgOp.region();
-      rewriter.inlineRegionBefore(op.region(), region, region.end());
-
-      // TODO: verify the internal memref-based linalg functionality.
-      auto &entryBlock = region.front();
-      for (auto result : results) {
-        auto type = result.getType().cast<ShapedType>();
-        entryBlock.addArgument(type.getElementType());
-      }
-      rewriter.eraseOp(op);
+      // Create a new block in the region of the new Generic Op.
+      Block &oldBlock = op.getRegion().front();
+      Region &newRegion = linalgOp.region();
+      Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
+                                             oldBlock.getArgumentTypes());
+
+      // Map the old block arguments to the new ones.
+      BlockAndValueMapping mapping;
+      mapping.map(oldBlock.getArguments(), newBlock->getArguments());
+
+      // Add the result arguments to the new block.
+      for (auto result : newResults)
+        newBlock->addArgument(
+            result.getType().cast<ShapedType>().getElementType());
+
+      // Clone the body of the old block to the new block.
+      rewriter.setInsertionPointToEnd(newBlock);
+      for (auto &op : oldBlock.getOperations())
+        rewriter.clone(op, mapping);
+
+      // Replace the results of the old Generic Op with the results of the new
+      // one.
+      rewriter.replaceOp(op, newResults);
       return success();
     }
   };
@@ -94,34 +108,33 @@ struct TestBufferPlacementPreparationPass
     patterns->insert<
                    FunctionAndBlockSignatureConverter,
                    GenericOpConverter,
-                   NoBufferOperandsReturnOpConverter<
+                   BufferAssignmentReturnOpConverter<
                       ReturnOp, ReturnOp, linalg::CopyOp>
     >(context, placer, converter);
     // clang-format on
   }
 
   void runOnOperation() override {
-    auto &context = getContext();
+    MLIRContext &context = getContext();
     ConversionTarget target(context);
     BufferAssignmentTypeConverter converter;
+
+    // Mark all Standard operations legal.
     target.addLegalDialect<StandardOpsDialect>();
 
-    // Make all linalg operations illegal as long as they work on tensors.
+    // Mark all Linalg operations illegal as long as they work on tensors.
+    auto isIllegalType = [&](Type type) { return !converter.isLegal(type); };
+    auto isLegalOperation = [&](Operation *op) {
+      return llvm::none_of(op->getOperandTypes(), isIllegalType) &&
+             llvm::none_of(op->getResultTypes(), isIllegalType);
+    };
     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(
         Optional<ConversionTarget::DynamicLegalityCallbackFn>(
-            [&](Operation *op) {
-              auto isIllegalType = [&](Type type) {
-                return !converter.isLegal(type);
-              };
-              return llvm::none_of(op->getOperandTypes(), isIllegalType) &&
-                     llvm::none_of(op->getResultTypes(), isIllegalType);
-            }));
-
-    // Mark std.ReturnOp illegal as long as an operand is tensor or buffer.
+            isLegalOperation));
+
+    // Mark Standard Return operations illegal as long as one operand is tensor.
     target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
-      return llvm::none_of(returnOp.getOperandTypes(), [&](Type type) {
-        return type.isa<MemRefType>() || !converter.isLegal(type);
-      });
+      return llvm::none_of(returnOp.getOperandTypes(), isIllegalType);
     });
 
     // Mark the function whose arguments are in tensor-type illegal.
@@ -130,16 +143,14 @@ struct TestBufferPlacementPreparationPass
     });
 
     // Walk over all the functions to apply buffer assignment.
-    getOperation().walk([&](FuncOp function) {
+    getOperation().walk([&](FuncOp function) -> WalkResult {
       OwningRewritePatternList patterns;
       BufferAssignmentPlacer placer(function);
       populateTensorLinalgToBufferLinalgConversionPattern(
           &context, &placer, &converter, &patterns);
 
       // Applying full conversion
-      return failed(applyFullConversion(function, target, patterns, &converter))
-                 ? WalkResult::interrupt()
-                 : WalkResult::advance();
+      return applyFullConversion(function, target, patterns, &converter);
     });
   };
 };


        


More information about the Mlir-commits mailing list