[llvm-branch-commits] [mlir] ced0b8e - [MLIR][BufferPlacement] Support functions that return Memref typed results
Ehsan Toosi via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri May 29 01:19:12 PDT 2020
Author: Ehsan Toosi
Date: 2020-05-29T10:13:02+02:00
New Revision: ced0b8eb2b443f40da34f3fff5aa64b555fa95fe
URL: https://github.com/llvm/llvm-project/commit/ced0b8eb2b443f40da34f3fff5aa64b555fa95fe
DIFF: https://github.com/llvm/llvm-project/commit/ced0b8eb2b443f40da34f3fff5aa64b555fa95fe.diff
LOG: [MLIR][BufferPlacement] Support functions that return Memref typed results
Buffer placement can now operate on functions that return buffers. These
buffers escape the deallocation phase of buffer placement.
Differential Revision: https://reviews.llvm.org/D80696
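For illustration, a minimal sketch of the new behaviour (the function name and
types here are hypothetical; the pattern follows the tests added below). A
tensor result is rewritten into an extra memref out-argument that is filled via
a copy before the return, while a memref result stays a result and its buffer
escapes deallocation:

  // Before buffer placement preparation: one tensor and one memref result.
  func @f(%arg0: tensor<4xf32>) -> (tensor<4xf32>, memref<8xf32>)

  // After: the converted tensor result becomes an out-argument; the memref
  // result remains a result and is not deallocated inside the function.
  func @f(%arg0: memref<4xf32>, %out: memref<4xf32>) -> memref<8xf32>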
Added:
Modified:
mlir/include/mlir/Transforms/BufferPlacement.h
mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
mlir/lib/Transforms/BufferPlacement.cpp
mlir/test/Transforms/buffer-placement-preparation.mlir
mlir/test/Transforms/buffer-placement.mlir
mlir/test/lib/Transforms/TestBufferPlacement.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h
index 030b87599d06..10949160fcbd 100644
--- a/mlir/include/mlir/Transforms/BufferPlacement.h
+++ b/mlir/include/mlir/Transforms/BufferPlacement.h
@@ -76,11 +76,23 @@ class BufferAssignmentOpConversionPattern
TypeConverter *converter;
};
-/// Converts the signature of the function using the type converter.
-/// It adds an extra argument for each illegally-typed function
-/// result to the function arguments. `BufferAssignmentTypeConverter`
-/// is a helper `TypeConverter` for this purpose. All the non-shaped types
-/// of the input function will be converted to memref.
+/// A helper type converter class for use inside Buffer Assignment operation
+/// conversion patterns. The default constructor keeps all types intact except
+/// for ranked tensor types, which are converted to memref types.
+class BufferAssignmentTypeConverter : public TypeConverter {
+public:
+ BufferAssignmentTypeConverter();
+
+ /// A helper function to check if `type` has been converted from non-memref
+ /// type to memref.
+ static bool isConvertedMemref(Type type, Type before);
+};
+
+/// Converts the signature of the function using the type converter. It adds
+/// an extra argument for each function result that will be of memref type
+/// after type conversion. The other function result types remain unchanged.
+/// `BufferAssignmentTypeConverter` is a helper `TypeConverter` for this
+/// purpose.
class FunctionAndBlockSignatureConverter
: public BufferAssignmentOpConversionPattern<FuncOp> {
public:
@@ -93,12 +105,14 @@ class FunctionAndBlockSignatureConverter
ConversionPatternRewriter &rewriter) const final;
};
-/// Converts the source `ReturnOp` to target `ReturnOp`, removes all
-/// the buffer operands from the operands list, and inserts `CopyOp`s
-/// for all buffer operands instead.
+/// Rewrites the `ReturnOp` to conform with the changed function signature.
+/// Operands that correspond to return values that have been rewritten from
+/// tensor results to memref arguments are dropped. In their place, a
+/// corresponding copy operation from the operand to the new function argument
+/// is inserted.
template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
typename CopyOpTy>
-class NoBufferOperandsReturnOpConverter
+class BufferAssignmentReturnOpConverter
: public BufferAssignmentOpConversionPattern<ReturnOpSourceTy> {
public:
using BufferAssignmentOpConversionPattern<
@@ -108,50 +122,41 @@ class NoBufferOperandsReturnOpConverter
LogicalResult
matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
+    // Split the operands into converted memref operands and the remaining
+    // operands.
+ SmallVector<Value, 2> needCopyOperands, newOperands;
+ unsigned operandsSize = operands.size();
+ needCopyOperands.reserve(operandsSize);
+ newOperands.reserve(operandsSize);
+ for (auto operand : llvm::enumerate(operands))
+ if (BufferAssignmentTypeConverter::isConvertedMemref(
+ operand.value().getType(),
+ returnOp.getOperand(operand.index()).getType()))
+ needCopyOperands.push_back(operand.value());
+ else
+ newOperands.push_back(operand.value());
+
Block &entryBlock = returnOp.getParentRegion()->front();
unsigned numFuncArgs = entryBlock.getNumArguments();
- Location loc = returnOp.getLoc();
-
- // The target `ReturnOp` should not contain any memref operands.
- SmallVector<Value, 2> newOperands(operands.begin(), operands.end());
- llvm::erase_if(newOperands, [](Value operand) {
- return operand.getType().isa<MemRefType>();
- });
// Find the index of the first destination buffer.
- unsigned numBufferOperands = operands.size() - newOperands.size();
- unsigned destArgNum = numFuncArgs - numBufferOperands;
-
+    assert(needCopyOperands.size() <= numFuncArgs &&
+           "The number of operands of the return operation exceeds the "
+           "number of function arguments.");
+ unsigned destArgNum = numFuncArgs - needCopyOperands.size();
rewriter.setInsertionPoint(returnOp);
- // Find the corresponding destination buffer for each memref operand.
- for (Value operand : operands)
- if (operand.getType().isa<MemRefType>()) {
- assert(destArgNum < numFuncArgs &&
- "The number of operands of return operation is more than the "
- "number of function argument.");
-
- // For each memref type operand of the source `ReturnOp`, a new `CopyOp`
- // is inserted that copies the buffer content from the operand to the
- // target.
- rewriter.create<CopyOpTy>(loc, operand,
- entryBlock.getArgument(destArgNum));
- ++destArgNum;
- }
+ for (Value operand : needCopyOperands) {
+ // Insert a `CopyOp` for each converted memref-type operand.
+ rewriter.create<CopyOpTy>(returnOp.getLoc(), operand,
+ entryBlock.getArgument(destArgNum));
+ ++destArgNum;
+ }
// Insert the new target Return operation.
rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, newOperands);
return success();
}
};
-
-/// A helper type converter class for using inside Buffer Assignment operation
-/// conversion patterns. The default constructor keeps all the types intact
-/// except for the ranked-tensor types which is converted to memref types.
-class BufferAssignmentTypeConverter : public TypeConverter {
-public:
- BufferAssignmentTypeConverter();
-};
-
} // end namespace mlir
#endif // MLIR_TRANSFORMS_BUFFERPLACEMENT_H
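For reference, a minimal sketch of how a client pass might register the renamed
pattern together with the signature converter (the populate function is a
hypothetical example; the instantiation mirrors TestBufferPlacement.cpp below):

  // Hypothetical helper; assumes the declarations from BufferPlacement.h above.
  static void populateMyBufferAssignmentPatterns(
      MLIRContext *context, BufferAssignmentPlacer *placer,
      BufferAssignmentTypeConverter *converter,
      OwningRewritePatternList *patterns) {
    patterns->insert<FunctionAndBlockSignatureConverter,
                     BufferAssignmentReturnOpConverter<
                         mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>>(
        context, placer, converter);
  }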
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
index 9b5855dff0ce..c663eb6017e5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -21,7 +21,7 @@
using namespace mlir;
using ReturnOpConverter =
- NoBufferOperandsReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
+ BufferAssignmentReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
linalg::CopyOp>;
namespace {
diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
index cd0641c1ac32..60f49d4e305c 100644
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -389,7 +389,13 @@ struct BufferPlacementPass
// If there is an existing dealloc, move it to the right place.
Operation *nextOp = positions.getDeallocPosition()->getNextNode();
- assert(nextOp && "Invalid Dealloc operation position");
+    // If the dealloc position is at the terminator operation of the block,
+    // the value escapes deallocation.
+ if (!nextOp) {
+ assert(deallocs.size() == 0 &&
+ "There should be no dealloc for the returned buffer");
+ continue;
+ }
if (deallocs.size()) {
(*deallocs.begin())->moveBefore(nextOp);
} else {
@@ -431,11 +437,6 @@ LogicalResult FunctionAndBlockSignatureConverter::matchAndRewrite(
return failure();
}
auto funcType = funcOp.getType();
- TypeRange resultTypes = funcType.getResults();
- if (llvm::any_of(resultTypes,
- [](Type type) { return type.isa<MemRefType>(); }))
- return funcOp.emitError("BufferAssignmentPlacer doesn't currently support "
- "functions which return memref typed values");
// Convert function arguments using the provided TypeConverter.
TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
@@ -443,17 +444,16 @@ LogicalResult FunctionAndBlockSignatureConverter::matchAndRewrite(
conversion.addInputs(argType.index(),
converter->convertType(argType.value()));
- // Adding a function argument for each function result which is going to be a
- // memref type after type conversion.
+ // If a function result type is not a memref but it would be a memref after
+ // type conversion, a new argument should be appended to the function
+ // arguments list for this result. Otherwise, it remains unchanged as a
+ // function result.
SmallVector<Type, 2> newResultTypes;
newResultTypes.reserve(funcOp.getNumResults());
- for (Type resType : resultTypes) {
+ for (Type resType : funcType.getResults()) {
Type convertedType = converter->convertType(resType);
-
- // If the result type is memref after the type conversion, a new argument
- // should be appended to the function arguments list for this result.
- // Otherwise, it remains unchanged as a function result.
- if (convertedType.isa<MemRefType>())
+ if (BufferAssignmentTypeConverter::isConvertedMemref(convertedType,
+ resType))
conversion.addInputs(convertedType);
else
newResultTypes.push_back(convertedType);
@@ -482,6 +482,11 @@ BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() {
});
}
+/// Checks if `type` has been converted from non-memref type to memref.
+bool BufferAssignmentTypeConverter::isConvertedMemref(Type type, Type before) {
+ return type.isa<MemRefType>() && !before.isa<MemRefType>();
+}
+
//===----------------------------------------------------------------------===//
// BufferPlacementPass construction
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/buffer-placement-preparation.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir
index ef7a2e328da5..8458154e4985 100644
--- a/mlir/test/Transforms/buffer-placement-preparation.mlir
+++ b/mlir/test/Transforms/buffer-placement-preparation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-buffer-placement-preparation -split-input-file -verify-diagnostics %s | FileCheck %s -dump-input-on-failure
+// RUN: mlir-opt -test-buffer-placement-preparation -split-input-file %s | FileCheck %s -dump-input-on-failure
// CHECK-LABEL: func @func_signature_conversion
func @func_signature_conversion(%arg0: tensor<4x8xf32>) {
@@ -8,12 +8,28 @@ func @func_signature_conversion(%arg0: tensor<4x8xf32>) {
// -----
-// expected-error @below {{BufferAssignmentPlacer doesn't currently support functions which return memref typed values}}
-// expected-error @below {{failed to legalize operation 'func'}}
-func @memref_in_function_results(%arg0: tensor<4x8xf32>) -> (tensor<4x8xf32>, memref<5xf32>) {
- %0 = alloc() : memref<5xf32>
- return %arg0, %0 : tensor<4x8xf32>, memref<5xf32>
+// Only tensor-typed function results should be converted to memrefs and moved
+// to the function arguments list. The memref function results remain as
+// function results.
+
+#map0 = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: func @memref_in_function_results
+func @memref_in_function_results(%arg0: tensor<5xf32>, %arg1: memref<10xf32>) -> (tensor<5xf32>, memref<10xf32>, memref<15xf32>) {
+ %0 = alloc() : memref<15xf32>
+ %1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 {
+ ^bb0(%gen1_arg0: f32):
+ %tmp1 = exp %gen1_arg0 : f32
+ linalg.yield %tmp1 : f32
+ }: tensor<5xf32> -> tensor<5xf32>
+ return %1, %arg1, %0 : tensor<5xf32>, memref<10xf32>, memref<15xf32>
}
+// CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>, %[[RESULT:.*]]: memref<5xf32>)
+// CHECK-SAME: (memref<10xf32>, memref<15xf32>)
+// CHECK: %[[FIRST_ALLOC:.*]] = alloc()
+// CHECK: %[[LINALG_ALLOC:.*]] = alloc()
+// CHECK: linalg.copy(%[[LINALG_ALLOC]], %[[RESULT]])
+// CHECK: return %[[ARG1]], %[[FIRST_ALLOC]]
// -----
diff --git a/mlir/test/Transforms/buffer-placement.mlir b/mlir/test/Transforms/buffer-placement.mlir
index afbf34ce43fb..4b401cc841af 100644
--- a/mlir/test/Transforms/buffer-placement.mlir
+++ b/mlir/test/Transforms/buffer-placement.mlir
@@ -457,3 +457,32 @@ func @nested_regions_and_cond_branch(%arg0: i1, %arg1: memref<2xf32>, %arg2: mem
// CHECK: ^[[BB3:.*]]({{.*}}):
// CHECK: linalg.copy
// CHECK-NEXT: dealloc %[[GENERIC1_ALLOC]]
+
+// -----
+
+// Test Case: buffer deallocation escaping
+// BufferPlacement Expected Behaviour: It must not dealloc %arg1 and %x
+// since they are operands of the return operation and should escape
+// deallocation. It should dealloc %y after linalg.copy.
+
+#map0 = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: func @memref_in_function_results
+func @memref_in_function_results(%arg0: memref<5xf32>, %arg1: memref<10xf32>, %arg2: memref<5xf32>) -> (memref<10xf32>, memref<15xf32>) {
+ %x = alloc() : memref<15xf32>
+ %y = alloc() : memref<5xf32>
+ linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0, %y {
+ ^bb0(%arg3: f32, %arg4: f32):
+ %2 = exp %arg3 : f32
+ linalg.yield %2 : f32
+ }: memref<5xf32>, memref<5xf32>
+ linalg.copy(%y, %arg2) : memref<5xf32>, memref<5xf32>
+ return %arg1, %x : memref<10xf32>, memref<15xf32>
+}
+// CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>, %[[RESULT:.*]]: memref<5xf32>)
+// CHECK: %[[X:.*]] = alloc()
+// CHECK: %[[Y:.*]] = alloc()
+// CHECK: linalg.copy
+// CHECK: dealloc %[[Y]]
+// CHECK: return %[[ARG1]], %[[X]]
+
diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
index 2d781e64cdfa..6152a9b85435 100644
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -23,7 +23,7 @@ using namespace mlir;
namespace {
/// This pass tests the computeAllocPosition helper method and two provided
/// operation converters, FunctionAndBlockSignatureConverter and
-/// NoBufferOperandsReturnOpConverter. Furthermore, this pass converts linalg
+/// BufferAssignmentReturnOpConverter. Furthermore, this pass converts linalg
/// operations on tensors to linalg operations on buffers to prepare them for
/// the BufferPlacement pass that can be applied afterwards.
struct TestBufferPlacementPreparationPass
@@ -41,16 +41,18 @@ struct TestBufferPlacementPreparationPass
LogicalResult
matchAndRewrite(linalg::GenericOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- auto loc = op.getLoc();
- SmallVector<Value, 4> args(operands.begin(), operands.end());
+ Location loc = op.getLoc();
+ ResultRange results = op.getOperation()->getResults();
+ SmallVector<Value, 2> newArgs, newResults;
+ newArgs.reserve(operands.size() + results.size());
+ newArgs.append(operands.begin(), operands.end());
+ newResults.reserve(results.size());
// Update all types to memref types.
- auto results = op.getOperation()->getResults();
for (auto result : results) {
- auto type = result.getType().cast<ShapedType>();
- if (!type)
- op.emitOpError()
- << "tensor to buffer conversion expects ranked results";
+ ShapedType type = result.getType().cast<ShapedType>();
+      assert(type && "Generic operations with non-shaped results are not "
+                     "currently supported.");
if (!type.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "dynamic shapes not currently supported");
@@ -62,27 +64,39 @@ struct TestBufferPlacementPreparationPass
rewriter.restoreInsertionPoint(
bufferAssignment->computeAllocPosition(result));
auto alloc = rewriter.create<AllocOp>(loc, memrefType);
- result.replaceAllUsesWith(alloc);
- args.push_back(alloc);
+ newArgs.push_back(alloc);
+ newResults.push_back(alloc);
}
// Generate a new linalg operation that works on buffers.
auto linalgOp = rewriter.create<linalg::GenericOp>(
- loc, llvm::None, args, rewriter.getI64IntegerAttr(operands.size()),
+ loc, llvm::None, newArgs, rewriter.getI64IntegerAttr(operands.size()),
rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(),
op.iterator_types(), op.docAttr(), op.library_callAttr());
- // Move regions from the old operation to the new one.
- auto ®ion = linalgOp.region();
- rewriter.inlineRegionBefore(op.region(), region, region.end());
-
- // TODO: verify the internal memref-based linalg functionality.
- auto &entryBlock = region.front();
- for (auto result : results) {
- auto type = result.getType().cast<ShapedType>();
- entryBlock.addArgument(type.getElementType());
- }
- rewriter.eraseOp(op);
+ // Create a new block in the region of the new Generic Op.
+ Block &oldBlock = op.getRegion().front();
+ Region &newRegion = linalgOp.region();
+ Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
+ oldBlock.getArgumentTypes());
+
+ // Map the old block arguments to the new ones.
+ BlockAndValueMapping mapping;
+ mapping.map(oldBlock.getArguments(), newBlock->getArguments());
+
+ // Add the result arguments to the new block.
+ for (auto result : newResults)
+ newBlock->addArgument(
+ result.getType().cast<ShapedType>().getElementType());
+
+ // Clone the body of the old block to the new block.
+ rewriter.setInsertionPointToEnd(newBlock);
+ for (auto &op : oldBlock.getOperations())
+ rewriter.clone(op, mapping);
+
+ // Replace the results of the old Generic Op with the results of the new
+ // one.
+ rewriter.replaceOp(op, newResults);
return success();
}
};
@@ -94,34 +108,33 @@ struct TestBufferPlacementPreparationPass
patterns->insert<
FunctionAndBlockSignatureConverter,
GenericOpConverter,
- NoBufferOperandsReturnOpConverter<
+ BufferAssignmentReturnOpConverter<
ReturnOp, ReturnOp, linalg::CopyOp>
>(context, placer, converter);
// clang-format on
}
void runOnOperation() override {
- auto &context = getContext();
+ MLIRContext &context = getContext();
ConversionTarget target(context);
BufferAssignmentTypeConverter converter;
+
+ // Mark all Standard operations legal.
target.addLegalDialect<StandardOpsDialect>();
- // Make all linalg operations illegal as long as they work on tensors.
+ // Mark all Linalg operations illegal as long as they work on tensors.
+ auto isIllegalType = [&](Type type) { return !converter.isLegal(type); };
+ auto isLegalOperation = [&](Operation *op) {
+ return llvm::none_of(op->getOperandTypes(), isIllegalType) &&
+ llvm::none_of(op->getResultTypes(), isIllegalType);
+ };
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(
Optional<ConversionTarget::DynamicLegalityCallbackFn>(
- [&](Operation *op) {
- auto isIllegalType = [&](Type type) {
- return !converter.isLegal(type);
- };
- return llvm::none_of(op->getOperandTypes(), isIllegalType) &&
- llvm::none_of(op->getResultTypes(), isIllegalType);
- }));
-
- // Mark std.ReturnOp illegal as long as an operand is tensor or buffer.
+ isLegalOperation));
+
+  // Mark Standard Return operations illegal as long as any operand is a tensor.
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
- return llvm::none_of(returnOp.getOperandTypes(), [&](Type type) {
- return type.isa<MemRefType>() || !converter.isLegal(type);
- });
+ return llvm::none_of(returnOp.getOperandTypes(), isIllegalType);
});
// Mark the function whose arguments are in tensor-type illegal.
@@ -130,16 +143,14 @@ struct TestBufferPlacementPreparationPass
});
// Walk over all the functions to apply buffer assignment.
- getOperation().walk([&](FuncOp function) {
+ getOperation().walk([&](FuncOp function) -> WalkResult {
OwningRewritePatternList patterns;
BufferAssignmentPlacer placer(function);
populateTensorLinalgToBufferLinalgConversionPattern(
&context, &placer, &converter, &patterns);
// Applying full conversion
- return failed(applyFullConversion(function, target, patterns, &converter))
- ? WalkResult::interrupt()
- : WalkResult::advance();
+ return applyFullConversion(function, target, patterns, &converter);
});
};
};