[Mlir-commits] [mlir] 94f5d24 - [mlir] Extend BufferAssignmentTypeConverter with result conversion callbacks
Ehsan Toosi
llvmlistbot at llvm.org
Wed Sep 2 04:32:37 PDT 2020
Author: Ehsan Toosi
Date: 2020-09-02T13:26:55+02:00
New Revision: 94f5d248772ba0f1f9c8b0746fe75a5d246c5540
URL: https://github.com/llvm/llvm-project/commit/94f5d248772ba0f1f9c8b0746fe75a5d246c5540
DIFF: https://github.com/llvm/llvm-project/commit/94f5d248772ba0f1f9c8b0746fe75a5d246c5540.diff
LOG: [mlir] Extend BufferAssignmentTypeConverter with result conversion callbacks
In this PR, the users of BufferPlacement can configure
BufferAssignmentTypeConverter. These new configuration options give the user more
freedom when converting function signatures and when converting return and call
operations.
These are the new features:
- Accepting callback functions for decomposing types (i.e. 1 to N type
conversion such as unpacking tuple types).
- Defining ResultConversionKind for specifying whether a function result
with a certain type should be appended to the function arguments list or
should be kept as a function result. (Usage:
converter.setResultConversionKind<RankedTensorType, MemRefType>(
BufferAssignmentTypeConverter::AppendToArgumentsList))
- Accepting callback functions for composing or decomposing values (i.e. N
to 1 and 1 to N value conversion).
Differential Revision: https://reviews.llvm.org/D85133
Added:
Modified:
mlir/include/mlir/Transforms/BufferPlacement.h
mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
mlir/lib/Transforms/BufferPlacement.cpp
mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
mlir/test/Transforms/buffer-placement-preparation.mlir
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Transforms/TestBufferPlacement.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h
index f8559a9dd939..8fc254e6be1e 100644
--- a/mlir/include/mlir/Transforms/BufferPlacement.h
+++ b/mlir/include/mlir/Transforms/BufferPlacement.h
@@ -52,6 +52,111 @@ class BufferAssignmentPlacer {
Operation *operation;
};
+/// A helper type converter class for using inside Buffer Assignment operation
+/// conversion patterns. The default constructor keeps all the types intact
+/// except for the ranked-tensor types which are converted to memref types.
+class BufferAssignmentTypeConverter : public TypeConverter {
+public:
+ /// Describes how buffer placement operation converters should treat a
+ /// result of a certain type after type conversion: either append it to the
+ /// function arguments list or keep it as a function result. The kind is
+ /// registered per (origin type, converted type) mapping with
+ /// setResultConversionKind and queried with getResultConversionKind.
+ enum ResultConversionKind { AppendToArgumentsList, KeepAsFunctionResult };
+
+ BufferAssignmentTypeConverter();
+
+ /// Tries to decompose a value of a certain type using the registered
+ /// decompose value callbacks, appending the resulting values to the output
+ /// list. If no callback accepts the type, the original value is appended.
+ void tryDecomposeValue(OpBuilder &, Location, Type, Value,
+ SmallVectorImpl<Value> &);
+
+ /// Tries to decompose a type using the registered decompose type callbacks,
+ /// appending the resulting types to the output list. If no callback accepts
+ /// the type, the original type is appended.
+ void tryDecomposeType(Type, SmallVectorImpl<Type> &);
+
+ /// Registers a callback that is called to decompose a value of a certain
+ /// type into several values. The type of the callback's third parameter
+ /// (`T`) selects the values it is applied to.
+ template <typename FnT,
+ typename T = typename llvm::function_traits<FnT>::template arg_t<2>>
+ void addDecomposeValueConversion(FnT &&callback) {
+ decomposeValueConversions.emplace_back(
+ wrapDecomposeValueConversionCallback<T>(std::forward<FnT>(callback)));
+ }
+
+ /// Registers a callback that is called to decompose a type into several
+ /// types. The same callback is also registered with the underlying
+ /// TypeConverter through addConversion.
+ template <typename FnT,
+ typename T = typename llvm::function_traits<FnT>::template arg_t<0>>
+ void addDecomposeTypeConversion(FnT &&callback) {
+ // `callback` is consumed twice. Hand a copy to the wrapper and forward
+ // (possibly move) the original only once; forwarding it twice would give
+ // addConversion a moved-from callable.
+ auto wrapper = wrapDecomposeTypeConversionCallback<T>(callback);
+ decomposeTypeConversions.emplace_back(std::move(wrapper));
+ addConversion(std::forward<FnT>(callback));
+ }
+
+ /// Returns the ResultConversionKind registered for the mapping from the
+ /// `origin` type to the converted `input` type (KeepAsFunctionResult when no
+ /// registered mapping matches).
+ ResultConversionKind getResultConversionKind(Type origin, Type input);
+
+ /// Registers `kind` as the ResultConversionKind for the mapping from type
+ /// 'T' to type 'U'. Only memref types can currently be appended to the
+ /// function arguments list.
+ template <typename T, typename U>
+ void setResultConversionKind(ResultConversionKind kind) {
+ assert((kind != AppendToArgumentsList ||
+ llvm::is_one_of<U, MemRefType, UnrankedMemRefType>::value) &&
+ "Only the memref typed values can be set to be appended to the "
+ "function argument list at the moment");
+ // Capture `kind` by value: the callback is stored and invoked after this
+ // function returns, so a by-reference capture of the parameter would
+ // dangle.
+ resultTypeConversions.emplace_back(
+ [kind](Type origin, Type input) -> Optional<ResultConversionKind> {
+ if (origin.template isa<T>() && input.template isa<U>())
+ return kind;
+ return llvm::None;
+ });
+ }
+
+private:
+ using DecomposeValueConversionCallFn = std::function<Optional<LogicalResult>(
+ OpBuilder &, Location, Type, Value, SmallVectorImpl<Value> &)>;
+
+ using DecomposeTypeConversionCallFn =
+ std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
+
+ using ResultConversionKindFn =
+ std::function<Optional<ResultConversionKind>(Type, Type)>;
+
+ /// Generate a wrapper for the given decompose value conversion callback.
+ /// The wrapper returns llvm::None when the type is not a `T`.
+ template <typename T, typename FnT>
+ DecomposeValueConversionCallFn
+ wrapDecomposeValueConversionCallback(FnT &&callback) {
+ return [callback = std::forward<FnT>(callback)](
+ OpBuilder &builder, Location loc, Type type, Value value,
+ SmallVectorImpl<Value> &newValues) -> Optional<LogicalResult> {
+ if (T derivedType = type.dyn_cast<T>())
+ return callback(builder, loc, derivedType, value, newValues);
+ return llvm::None;
+ };
+ }
+
+ /// Generate a wrapper for the given decompose type conversion callback.
+ /// The wrapper returns llvm::None when the type is not a `T`.
+ template <typename T, typename FnT>
+ DecomposeTypeConversionCallFn
+ wrapDecomposeTypeConversionCallback(FnT &&callback) {
+ return [callback = std::forward<FnT>(callback)](
+ Type type,
+ SmallVectorImpl<Type> &results) -> Optional<LogicalResult> {
+ T derivedType = type.dyn_cast<T>();
+ if (!derivedType)
+ return llvm::None;
+ return callback(derivedType, results);
+ };
+ }
+
+ /// Registered callbacks, each list kept in registration order.
+ SmallVector<ResultConversionKindFn, 2> resultTypeConversions;
+ SmallVector<DecomposeValueConversionCallFn, 2> decomposeValueConversions;
+ SmallVector<DecomposeTypeConversionCallFn, 2> decomposeTypeConversions;
+};
+
/// Helper conversion pattern that encapsulates a BufferAssignmentPlacer
/// instance. Sample usage:
/// class CustomConversionPattern : public
@@ -68,43 +173,22 @@ class BufferAssignmentOpConversionPattern
public:
explicit BufferAssignmentOpConversionPattern(
MLIRContext *context, BufferAssignmentPlacer *bufferAssignment = nullptr,
- TypeConverter *converter = nullptr, PatternBenefit benefit = 1)
+ BufferAssignmentTypeConverter *converter = nullptr,
+ PatternBenefit benefit = 1)
: OpConversionPattern<SourceOp>(context, benefit),
- bufferAssignment(bufferAssignment), converter(converter) {}
+ bufferAssignment(bufferAssignment), converter(converter) {
+ // NOTE(review): `converter` defaults to nullptr but is required; the
+ // derived func/call/return patterns dereference it without further checks.
+ assert(converter && "The type converter has not been defined");
+ }
protected:
BufferAssignmentPlacer *bufferAssignment;
- TypeConverter *converter;
-};
-
-/// A helper type converter class for using inside Buffer Assignment operation
-/// conversion patterns. The default constructor keeps all the types intact
-/// except for the ranked-tensor types which is converted to memref types.
-class BufferAssignmentTypeConverter : public TypeConverter {
-public:
- BufferAssignmentTypeConverter();
-
- /// A helper function to check if `type` has been converted from non-memref
- /// type to memref.
- static bool isConvertedMemref(Type type, Type before);
+ BufferAssignmentTypeConverter *converter;
};
-namespace detail {
-
-/// Converts the signature of the function based on whether the function is
-/// allowed to return memref typed results or not using
-/// `allowMemrefFunctionResults` parameter. If this option is false, then it
-/// adds an extra function argument as an output buffer for each function result
-/// which is going to be a memref type only after type conversion. The
-/// other function result types remain unchanged. If
-/// `allowMemrefFunctionResults` is true, the types are converted in place.
-/// Any changes in function signature need to be applied
-/// to return and caller operations. `BufferAssignmentReturnOpConverter` and
-/// `BufferAssignmentCallOpConverter` are two helper function that match the
-/// return and caller operation with the new function signature. Furthermore,
-/// `BufferAssignmentTypeConverter` is a helper `TypeConverter` for converting
-/// tensor typed values to memref typed ones.
-template <bool allowMemrefFunctionResults>
+/// Converts the signature of the function using BufferAssignmentTypeConverter.
+/// Each result type of the function is kept as a function result or appended to
+/// the function arguments list based on ResultConversionKind for the converted
+/// result type.
class BufferAssignmentFuncOpConverter
: public BufferAssignmentOpConversionPattern<FuncOp> {
public:
@@ -112,58 +196,16 @@ class BufferAssignmentFuncOpConverter
FuncOp>::BufferAssignmentOpConversionPattern;
/// Performs the actual signature rewriting step.
- LogicalResult
- matchAndRewrite(mlir::FuncOp funcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- if (!converter)
- return funcOp.emitError("The type converter has not been defined for "
- "BufferAssignmentFuncOpConverter");
- auto funcType = funcOp.getType();
-
- // Convert function arguments using the provided TypeConverter.
- TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
- for (auto argType : llvm::enumerate(funcType.getInputs()))
- conversion.addInputs(argType.index(),
- converter->convertType(argType.value()));
-
- // If allowMemrefFunctionResults is false and a function result type is not
- // a memref but it would be a memref after type conversion, a new argument
- // should be appended to the function arguments list for this result.
- // Otherwise, it remains unchanged as a function result.
- SmallVector<Type, 2> newResultTypes;
- newResultTypes.reserve(funcOp.getNumResults());
- for (Type resType : funcType.getResults()) {
- Type convertedType = converter->convertType(resType);
- if (!allowMemrefFunctionResults &&
- BufferAssignmentTypeConverter::isConvertedMemref(convertedType,
- resType))
- conversion.addInputs(convertedType);
- else
- newResultTypes.push_back(convertedType);
- }
- if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter,
- &conversion)))
- return failure();
-
- // Update the signature of the function.
- rewriter.updateRootInPlace(funcOp, [&] {
- funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
- newResultTypes));
- });
- return success();
- }
+ /// Defined out of line in BufferPlacement.cpp.
+ LogicalResult matchAndRewrite(mlir::FuncOp, ArrayRef<Value>,
+ ConversionPatternRewriter &) const override;
};
/// Rewrites the `ReturnOp` to conform with the changed function signature.
-/// if allowMemrefFunctionResults is false, operands that correspond to return
-/// values and have been rewritten from illegal typed results to memref
-/// arguments are dropped. In their place, a corresponding copy operation from
-/// the operand to the output function argument is inserted. Otherwise, the
-/// memref typed operands are returned.
-/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter,
-/// allowMemrefFunctionResults must be set/unset for both.
+/// Operands that correspond to return values and whose types have been set to
+/// AppendToArgumentsList are dropped. In their place, a corresponding copy
+/// operation from the operand to the target function argument is inserted.
template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
- typename CopyOpTy, bool allowMemrefFunctionResults>
+ typename CopyOpTy>
class BufferAssignmentReturnOpConverter
: public BufferAssignmentOpConversionPattern<ReturnOpSourceTy> {
public:
@@ -174,44 +216,48 @@ class BufferAssignmentReturnOpConverter
LogicalResult
matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- // If the memref typed results can be returned as function results, the new
- // `ReturnOp` should only return the type converted operands.
- if (allowMemrefFunctionResults) {
- rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, operands);
- return success();
+ Location loc = returnOp.getLoc();
+
+ // Split the operands depending on whether they need a copy operation or
+ // they remain as operands of the return operation. If an operand is
+ // decomposable and a decompose callback function has been provided by the
+ // user, it will be unpacked. New IR is created through the rewriter (not a
+ // separate OpBuilder) so the conversion driver tracks it and can roll it
+ // back if this pattern fails.
+ SmallVector<Value, 2> newOperands, needCopyOperands;
+ rewriter.setInsertionPoint(returnOp);
+ for (auto operand : llvm::enumerate(operands)) {
+ SmallVector<Value, 2> values;
+ this->converter->tryDecomposeValue(
+ rewriter, loc, operand.value().getType(), operand.value(), values);
+ Type type = returnOp.getOperand(operand.index()).getType();
+ SmallVector<Type, 2> originTypes;
+ this->converter->tryDecomposeType(type, originTypes);
+ // Each decomposed value is paired with the decomposed origin type at the
+ // same position; bail out instead of indexing past the end.
+ if (values.size() != originTypes.size())
+ return returnOp.emitError(
+ "The number of decomposed operand values does not match the number "
+ "of decomposed operand types.");
+ for (auto value : llvm::enumerate(values)) {
+ Type origin = originTypes[value.index()];
+ Type converted = value.value().getType();
+ auto kind = this->converter->getResultConversionKind(origin, converted);
+ if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult)
+ newOperands.push_back(value.value());
+ else
+ // kind == BufferAssignmentTypeConverter::AppendToArgumentsList
+ needCopyOperands.push_back(value.value());
+ }
}
- // Split the operands by their kinds whether they are converted memref or
- // not.
- SmallVector<Value, 2> needCopyOperands, newOperands;
- unsigned operandsSize = operands.size();
- needCopyOperands.reserve(operandsSize);
- newOperands.reserve(operandsSize);
- for (auto operand : llvm::enumerate(operands))
- if (BufferAssignmentTypeConverter::isConvertedMemref(
- operand.value().getType(),
- returnOp.getOperand(operand.index()).getType()))
- needCopyOperands.push_back(operand.value());
- else
- newOperands.push_back(operand.value());
-
+ // Insert Copy operations instead for the operands that have been removed
+ // from operand list and appended to the function arguments list.
Block &entryBlock = returnOp.getParentRegion()->front();
unsigned numFuncArgs = entryBlock.getNumArguments();
-
- // Find the index of the first destination buffer.
- assert(needCopyOperands.size() <= numFuncArgs &&
- "The number of operands of return operation is more than the "
- "number of function arguments.");
+ if (needCopyOperands.size() > numFuncArgs)
+ return returnOp.emitError(
+ "The number of operands that need Copy operations is more "
+ "than the number of target function arguments.");
unsigned destArgNum = numFuncArgs - needCopyOperands.size();
rewriter.setInsertionPoint(returnOp);
for (Value operand : needCopyOperands) {
- // Insert a `CopyOp` for each converted memref-type operand.
- rewriter.create<CopyOpTy>(returnOp.getLoc(), operand,
+ rewriter.create<CopyOpTy>(loc, operand,
entryBlock.getArgument(destArgNum));
++destArgNum;
}
-
- // Insert the new target Return operation.
rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, newOperands);
return success();
}
@@ -219,94 +265,32 @@ class BufferAssignmentReturnOpConverter
/// Rewrites the `CallOp` to match its operands and results with the signature
/// of the callee after rewriting the callee with
-/// BufferAssignmentFuncOpConverter. If allowMemrefFunctionResults is false, a
-/// buffer is allocated as an output buffer only for each memref typed result
-/// that has been rewritten. The new allocated buffer is passed through the
-/// operands list of the new `CallOp`.
-/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter,
-/// allowMemrefFunctionResults must be set/unset for both.
-template <bool allowMemrefFunctionResults>
+/// BufferAssignmentFuncOpConverter.
class BufferAssignmentCallOpConverter
: public BufferAssignmentOpConversionPattern<CallOp> {
public:
using BufferAssignmentOpConversionPattern<
CallOp>::BufferAssignmentOpConversionPattern;
- LogicalResult
- matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- if (!converter)
- return callOp.emitError("The type converter has not been defined for "
- "BufferAssignmentCallOpConverter");
- Location loc = callOp.getLoc();
-
- // If the memref typed results can be returned as function results, there is
- // no need to create output buffers. It is only required to convert the type
- // of operands and results in place for creating the new `CallOp`.
- if (allowMemrefFunctionResults) {
- SmallVector<Type, 2> resultTypes;
- resultTypes.reserve(callOp.getNumResults());
- for (Type type : callOp.getResultTypes())
- resultTypes.push_back(converter->convertType(type));
- rewriter.replaceOpWithNewOp<CallOp>(callOp, callOp.getCallee(),
- resultTypes, operands);
- return success();
- }
-
- SmallVector<Value, 2> newOperands, replacingValues;
- SmallVector<Type, 2> newResultTypes;
- unsigned numResults = callOp.getNumResults();
- newOperands.reserve(numResults + operands.size());
- newOperands.append(operands.begin(), operands.end());
- newResultTypes.reserve(numResults);
- replacingValues.reserve(numResults);
-
- // For each memref result of `CallOp` which has not been a memref before
- // the type conversion, a new buffer is allocated and passed to the operands
- // list of the new `CallOp`. Otherwise, it remains as a caller result.
- for (Value result : callOp.getResults()) {
- Type currType = result.getType();
- Type newType = converter->convertType(result.getType());
- if (BufferAssignmentTypeConverter::isConvertedMemref(newType, currType)) {
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.restoreInsertionPoint(bufferAssignment->computeAllocPosition(
- result.dyn_cast<OpResult>()));
- Value alloc =
- rewriter.create<AllocOp>(loc, newType.dyn_cast<MemRefType>());
- newOperands.push_back(alloc);
- replacingValues.push_back(alloc);
- } else {
- newResultTypes.push_back(currType);
-
- // No replacing is required.
- replacingValues.push_back(nullptr);
- }
- }
-
- // Creating the new `CallOp`.
- rewriter.create<CallOp>(loc, callOp.getCallee(), newResultTypes,
- newOperands);
-
- // Replacing the results of the old `CallOp`.
- rewriter.replaceOp(callOp, replacingValues);
- return success();
- }
+ /// Performs the actual rewriting step: replaces the `CallOp` with a new call
+ /// whose operands and results match the converted callee signature. Defined
+ /// out of line in BufferPlacement.cpp.
+ LogicalResult matchAndRewrite(CallOp, ArrayRef<Value>,
+ ConversionPatternRewriter &) const override;
};
-} // end namespace detail
/// Populates `patterns` with the conversion patterns of buffer
/// assignment.
+/// All three patterns are constructed with the same `context`, `placer` and
+/// `converter`; `converter` must be non-null because the shared pattern base
+/// class constructor asserts it.
template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
- typename CopyOpTy, bool allowMemrefFunctionResults>
+ typename CopyOpTy>
static void populateWithBufferAssignmentOpConversionPatterns(
MLIRContext *context, BufferAssignmentPlacer *placer,
- TypeConverter *converter, OwningRewritePatternList *patterns) {
+ BufferAssignmentTypeConverter *converter,
+ OwningRewritePatternList *patterns) {
// clang-format off
patterns->insert<
- detail::BufferAssignmentCallOpConverter<allowMemrefFunctionResults>,
- detail::BufferAssignmentFuncOpConverter<allowMemrefFunctionResults>,
- detail::BufferAssignmentReturnOpConverter
- <ReturnOpSourceTy, ReturnOpTargetTy, CopyOpTy, allowMemrefFunctionResults>
+ BufferAssignmentCallOpConverter,
+ BufferAssignmentFuncOpConverter,
+ BufferAssignmentReturnOpConverter
+ <ReturnOpSourceTy, ReturnOpTargetTy, CopyOpTy>
>(context, placer, converter);
// clang-format on
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
index 04c1fbd5d565..89a01f9ca629 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -100,11 +100,11 @@ class GenericOpConverter
/// tensors to buffers.
static void populateConvertLinalgOnTensorsToBuffersPattern(
MLIRContext *context, BufferAssignmentPlacer *placer,
- TypeConverter *converter, OwningRewritePatternList *patterns) {
+ BufferAssignmentTypeConverter *converter,
+ OwningRewritePatternList *patterns) {
+ // Function, call and return patterns. Returned values whose result kind is
+ // AppendToArgumentsList are copied into the output arguments with
+ // linalg::CopyOp.
populateWithBufferAssignmentOpConversionPatterns<
- mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp,
- /*allowMemrefFunctionResults=*/false>(context, placer, converter,
- patterns);
+ mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer,
+ converter, patterns);
patterns->insert<GenericOpConverter>(context, placer, converter);
}
@@ -141,6 +141,9 @@ struct ConvertLinalgOnTensorsToBuffers
converter.isLegal(&funcOp.getBody());
});
+ converter.setResultConversionKind<RankedTensorType, MemRefType>(
+ BufferAssignmentTypeConverter::AppendToArgumentsList);
+
// Walk over all the functions to apply buffer assignment.
getOperation().walk([&](FuncOp function) -> WalkResult {
OwningRewritePatternList patterns;
diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
index 201570a244ff..1ab3e7e2e48d 100644
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -713,9 +713,223 @@ BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() {
});
}
-/// Checks if `type` has been converted from non-memref type to memref.
-bool BufferAssignmentTypeConverter::isConvertedMemref(Type type, Type before) {
- return type.isa<BaseMemRefType>() && !before.isa<BaseMemRefType>();
+/// Tries to decompose `value` of type `type` using the registered decompose
+/// value callbacks, in registration order. The first callback that does not
+/// return llvm::None ends the search (even if it returned failure()). If no
+/// callback accepts the type, `value` itself is appended to `results`.
+void BufferAssignmentTypeConverter::tryDecomposeValue(
+ OpBuilder &builder, Location loc, Type type, Value value,
+ SmallVectorImpl<Value> &results) {
+ // Iterate by reference: copying a std::function may allocate.
+ for (const auto &conversion : decomposeValueConversions)
+ if (conversion(builder, loc, type, value, results) != llvm::None)
+ return;
+ results.push_back(value);
+}
+
+/// Tries to decompose `type` using the registered decompose type callbacks,
+/// in registration order. The first callback that does not return llvm::None
+/// ends the search. If no callback accepts the type, `type` itself is
+/// appended to `types`.
+void BufferAssignmentTypeConverter::tryDecomposeType(
+ Type type, SmallVectorImpl<Type> &types) {
+ // Iterate by reference: copying a std::function may allocate.
+ for (const auto &conversion : decomposeTypeConversions)
+ if (conversion(type, types) != llvm::None)
+ return;
+ types.push_back(type);
+}
+
+/// Returns the ResultConversionKind for the mapping from the `origin` type to
+/// the `converted` type. Registered mappings are tried in registration order
+/// and the first match wins; KeepAsFunctionResult is the default when none
+/// matches.
+BufferAssignmentTypeConverter::ResultConversionKind
+BufferAssignmentTypeConverter::getResultConversionKind(Type origin,
+ Type converted) {
+ // Iterate by reference: copying a std::function may allocate.
+ for (const auto &conversion : resultTypeConversions) {
+ auto res = conversion(origin, converted);
+ if (res.hasValue())
+ return res.getValue();
+ }
+ return KeepAsFunctionResult;
+}
+
+//===----------------------------------------------------------------------===//
+// BufferAssignmentFuncOpConverter
+//===----------------------------------------------------------------------===//
+
+/// Performs the actual function signature rewriting step. Each argument type
+/// is decomposed (if a decompose callback applies) and converted. Each
+/// decomposed result type is converted and then either kept as a function
+/// result or appended to the arguments list, according to its
+/// ResultConversionKind. Fails without modifying the function if any type
+/// conversion fails.
+LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite(
+ mlir::FuncOp funcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ auto funcType = funcOp.getType();
+
+ // Convert function arguments using the provided TypeConverter.
+ TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
+ for (auto argType : llvm::enumerate(funcType.getInputs())) {
+ SmallVector<Type, 2> decomposedTypes, convertedTypes;
+ converter->tryDecomposeType(argType.value(), decomposedTypes);
+ // Do not build a signature from a partially converted type list.
+ if (failed(converter->convertTypes(decomposedTypes, convertedTypes)))
+ return failure();
+ conversion.addInputs(argType.index(), convertedTypes);
+ }
+
+ // Convert the result types of the function.
+ SmallVector<Type, 2> newResultTypes;
+ newResultTypes.reserve(funcOp.getNumResults());
+ for (Type resultType : funcType.getResults()) {
+ SmallVector<Type, 2> originTypes;
+ converter->tryDecomposeType(resultType, originTypes);
+ for (Type origin : originTypes) {
+ Type converted = converter->convertType(origin);
+ // A null type signals a failed conversion; it must not reach
+ // getResultConversionKind, which queries it with isa<>.
+ if (!converted)
+ return failure();
+ auto kind = converter->getResultConversionKind(origin, converted);
+ if (kind == BufferAssignmentTypeConverter::AppendToArgumentsList)
+ conversion.addInputs(converted);
+ else
+ // kind == BufferAssignmentTypeConverter::KeepAsFunctionResult
+ newResultTypes.push_back(converted);
+ }
+ }
+
+ if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter,
+ &conversion)))
+ return failure();
+
+ // Update the signature of the function.
+ rewriter.updateRootInPlace(funcOp, [&] {
+ funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
+ newResultTypes));
+ });
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// BufferAssignmentCallOpConverter
+//===----------------------------------------------------------------------===//
+
+/// Performs the actual rewriting step.
+/// - Operands are decomposed when a decompose value callback applies.
+/// - Every result type is decomposed and converted.
+/// - Converted results whose kind is KeepAsFunctionResult become results of
+///   the new `CallOp`.
+/// - For results whose kind is AppendToArgumentsList, a buffer is allocated
+///   and passed as an extra operand instead.
+/// Finally, the uses of every old result are replaced. When a result maps to
+/// several values, they are packed with the argument materialization callback.
+LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
+ CallOp callOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+
+ // This class represents a mapping from an old result to a list of values.
+ // Some of these values are already available (e.g. allocated buffers).
+ // Others are results of the new `CallOp`, which has not been constructed
+ // yet; for those only their result index is known. The indices are replaced
+ // with the actual values once the new `CallOp` exists. The order in which
+ // entries are added is preserved.
+ class ResultMapping {
+ public:
+ /// Add an available value to the mapping.
+ void addMapping(Value value) {
+ toValuesMapping.push_back({order++, value});
+ }
+
+ /// Add the index of a not-yet-available result value to the mapping.
+ void addMapping(unsigned index) {
+ toIndicesMapping.push_back({order++, index});
+ }
+
+ /// Fills `values` with the mapping values in insertion order. Entries that
+ /// were recorded as indices are resolved against `valuesToReplaceIndices`.
+ void getMappingValues(ValueRange valuesToReplaceIndices,
+ SmallVectorImpl<Value> &values) {
+ // Append available values to the list.
+ SmallVector<std::pair<unsigned, Value>, 2> res(toValuesMapping.begin(),
+ toValuesMapping.end());
+ // Replace the indices with the actual values.
+ llvm::for_each(
+ toIndicesMapping, [&](const std::pair<unsigned, unsigned> &entry) {
+ assert(entry.second < valuesToReplaceIndices.size() &&
+ "The value index is out of range.");
+ res.push_back({entry.first, valuesToReplaceIndices[entry.second]});
+ });
+ // Sort the values based on their adding orders.
+ llvm::sort(res, [](const std::pair<unsigned, Value> &v1,
+ const std::pair<unsigned, Value> &v2) {
+ return v1.first < v2.first;
+ });
+ // Fill the values.
+ llvm::for_each(res, [&](const std::pair<unsigned, Value> &entry) {
+ values.push_back(entry.second);
+ });
+ }
+
+ private:
+ /// Insertion counter; unsigned to match the order field of the entries.
+ unsigned order = 0;
+
+ /// Containing the mapping values with their inserting orders.
+ SmallVector<std::pair<unsigned, Value>, 2> toValuesMapping;
+
+ /// Containing the indices of result values with their inserting orders.
+ SmallVector<std::pair<unsigned, unsigned>, 2> toIndicesMapping;
+ };
+
+ Location loc = callOp.getLoc();
+ // Create all new IR through the rewriter (not a separate OpBuilder) so the
+ // dialect conversion driver tracks it and can roll it back if this pattern
+ // fails. New operations are inserted right before `callOp`.
+ rewriter.setInsertionPoint(callOp);
+ SmallVector<Value, 2> newOperands;
+
+ // Create the operands list of the new `CallOp`. It unpacks the decomposable
+ // values if a decompose callback function has been provided by the user.
+ for (Value operand : operands) {
+ SmallVector<Value, 2> values;
+ converter->tryDecomposeValue(rewriter, loc, operand.getType(), operand,
+ values);
+ newOperands.append(values.begin(), values.end());
+ }
+
+ // Create the new result types for the new `CallOp` and a mapping from the old
+ // result to new value(s).
+ SmallVector<Type, 2> newResultTypes;
+ SmallVector<ResultMapping, 4> mappings;
+ mappings.resize(callOp.getNumResults());
+ for (auto result : llvm::enumerate(callOp.getResults())) {
+ SmallVector<Type, 2> originTypes;
+ converter->tryDecomposeType(result.value().getType(), originTypes);
+ auto &resultMapping = mappings[result.index()];
+ for (Type origin : originTypes) {
+ Type converted = converter->convertType(origin);
+ // A null type signals a failed conversion; it must not reach
+ // getResultConversionKind, which queries it with isa<>.
+ if (!converted)
+ return failure();
+ auto kind = converter->getResultConversionKind(origin, converted);
+ if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) {
+ newResultTypes.push_back(converted);
+ // The result value is not yet available. Its index is kept and it is
+ // replaced with the actual value of the new `CallOp` later.
+ resultMapping.addMapping(newResultTypes.size() - 1);
+ } else {
+ // kind == BufferAssignmentTypeConverter::AppendToArgumentsList
+ // The placer defaults to nullptr in the pattern constructor.
+ if (!bufferAssignment)
+ return callOp.emitError("BufferAssignmentPlacer has not been defined");
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.restoreInsertionPoint(
+ bufferAssignment->computeAllocPosition(result.value()));
+ MemRefType memref = converted.dyn_cast<MemRefType>();
+ if (!memref)
+ return callOp.emitError("Cannot allocate for a non-Memref type");
+ Value alloc = rewriter.create<AllocOp>(loc, memref);
+ newOperands.push_back(alloc);
+ resultMapping.addMapping(alloc);
+ }
+ }
+ }
+
+ CallOp newCallOp = rewriter.create<CallOp>(loc, callOp.getCallee(),
+ newResultTypes, newOperands);
+
+ // Build a replacing value for each result to replace its uses. If a result
+ // has multiple mapping values, it needs to be packed to a single value. The
+ // rewriter still inserts right before `callOp`, i.e. after `newCallOp`, so
+ // the packing operations can use its results.
+ SmallVector<Value, 2> replacedValues;
+ replacedValues.reserve(callOp.getNumResults());
+ for (unsigned i = 0, e = callOp.getNumResults(); i < e; ++i) {
+ SmallVector<Value, 2> valuesToPack;
+ mappings[i].getMappingValues(newCallOp.getResults(), valuesToPack);
+ if (valuesToPack.empty()) {
+ // No replacement is required.
+ replacedValues.push_back(nullptr);
+ } else if (valuesToPack.size() == 1) {
+ replacedValues.push_back(valuesToPack.front());
+ } else {
+ // Values need to be packed using callback function. The same callback
+ // that is used for materializeArgumentConversion is used for packing.
+ Value packed = converter->materializeArgumentConversion(
+ rewriter, loc, callOp.getType(i), valuesToPack);
+ // No materialization applied: fail rather than replace uses with null.
+ if (!packed)
+ return failure();
+ replacedValues.push_back(packed);
+ }
+ }
+ rewriter.replaceOp(callOp, replacedValues);
+ return success();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
index 084ac38af6e3..e1dacdf0184e 100644
--- a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
+++ b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
@@ -111,7 +111,73 @@ func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
// CHECK: %[[Y:.*]]:2 = call @callee(%[[X]]#0)
// CHECK: return %[[Y]]#0
+// -----
+
+// Test case: Testing BufferAssignmentCallOpConverter to see if the call matches
+// the new signature of the callee function when there are tuple-typed args and
+// results. BufferAssignmentTypeConverter is set to flatten tuple-typed arguments
+// and results. The tuple-typed values should be decomposed and composed using
+// get_tuple_element and make_tuple operations of the test dialect. Tensor types
+// are converted to memref. Memref-typed function results remain function results.
+// CHECK-LABEL: func @callee
+func @callee(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>){
+ return %arg0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
+}
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>)
+// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>)
+// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
+// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
+// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
+// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
+// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]
+// CHECK-LABEL: func @caller
+func @caller(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> tuple<tensor<2xf32>,i1, tensor<5xf32>>{
+ %x0 = call @callee(%arg0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
+ %y0 = call @callee(%x0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
+ return %y0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
+}
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>)
+// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>)
+// CHECK-NEXT: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
+// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 0 : i32}
+// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 1 : i32}
+// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 2 : i32}
+// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]])
+// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>)
+// CHECK-NEXT: %[[RESULT_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2)
+// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 0 : i32}
+// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 1 : i32}
+// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 2 : i32}
+// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]])
+// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>)
+// CHECK-NEXT: %[[RETURN_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2)
+// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 0 : i32}
+// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 1 : i32}
+// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 2 : i32}
+// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]
+// -----
+// Test case: Testing BufferAssignmentFuncOpConverter and
+// BufferAssignmentReturnOpConverter to see if the return operation matches
+// the new function signature when there are tuple-typed args and results.
+// BufferAssignmentTypeConverter is set to flatten tuple-typed arguments and
+// results. The tuple-typed values should be decomposed and composed using
+// get_tuple_element and make_tuple operations of the test dialect. Tensor types
+// are converted to memref. Memref-typed function results remain function results.
+
+// CHECK-LABEL: func @decompose_tuple_typed_function_args_and_results
+func @decompose_tuple_typed_function_args_and_results(%arg0: tuple<i1,f32>, %arg1: tensor<10xf32>, %arg2: tuple<i1, tensor<5xf32>>) -> (tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>){
+ return %arg2, %arg1, %arg0 : tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>
+}
+// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32>
+// CHECK-SAME: (i1, memref<5xf32>, memref<10xf32>, i1, f32)
+// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]])
+// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]])
+// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32}
+// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32}
+// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32}
+// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32}
+// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[SECOND_TUPLE_SECOND_ELEM]], %[[ARG2]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]]
diff --git a/mlir/test/Transforms/buffer-placement-preparation.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir
index 064b0fd7e85a..b1cfdfd690cf 100644
--- a/mlir/test/Transforms/buffer-placement-preparation.mlir
+++ b/mlir/test/Transforms/buffer-placement-preparation.mlir
@@ -285,8 +285,93 @@ func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
// CHECK: linalg.copy(%[[Y0]], %[[CALLER_RESULT]])
// CHECK: return
+// -----
+
// CHECK-LABEL: func @func_with_unranked_arg
func @func_with_unranked_arg(%arg0: tensor<*xf32>) {
return
}
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)
+
+// -----
+
+// Test case: Testing BufferAssignmentCallOpConverter to see if the call matches
+// the new signature of the callee function when there are tuple-typed args and
+// results. BufferAssignmentTypeConverter is set to flatten tuple-typed arguments
+// and results. The tuple-typed values should be decomposed and composed using
+// get_tuple_element and make_tuple operations of the test dialect. Tensor types
+// are converted to memref. Memref-typed function results are appended to the
+// function arguments list.
+
+// CHECK-LABEL: func @callee
+func @callee(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>){
+ return %arg0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
+}
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<2xf32>, %[[RESULT1:.*]]: memref<5xf32>)
+// CHECK-SAME: i1
+// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
+// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
+// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
+// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
+// CHECK-NEXT: linalg.copy(%[[FIRST_ELEM]], %[[RESULT0]])
+// CHECK-NEXT: linalg.copy(%[[THIRD_ELEM]], %[[RESULT1]])
+// CHECK-NEXT: return %[[SECOND_ELEM]]
+
+
+// CHECK-LABEL: func @caller
+func @caller(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> tuple<tensor<2xf32>,i1, tensor<5xf32>>{
+ %x0 = call @callee(%arg0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
+ %y0 = call @callee(%x0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
+ return %y0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
+}
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<2xf32>, %[[RESULT1:.*]]: memref<5xf32>)
+// CHECK-SAME: i1
+// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
+// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
+// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
+// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
+// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc()
+// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc()
+// CHECK-NEXT: %[[CALLEE_RESULT:.*]] = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]])
+// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>, memref<2xf32>, memref<5xf32>) -> i1
+// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[FIRST_ALLOC]], %[[CALLEE_RESULT]], %[[SECOND_ALLOC]])
+// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
+// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
+// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
+// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc()
+// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc()
+// CHECK-NEXT: %[[CALLEE_RESULT:.*]] = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]])
+// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>, memref<2xf32>, memref<5xf32>) -> i1
+// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[FIRST_ALLOC]], %[[CALLEE_RESULT]], %[[SECOND_ALLOC]])
+// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
+// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
+// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
+// CHECK-NEXT: linalg.copy(%[[FIRST_ELEM]], %[[RESULT0]])
+// CHECK-NEXT: linalg.copy(%[[THIRD_ELEM]], %[[RESULT1]])
+// CHECK-NEXT: return %[[SECOND_ELEM]]
+
+// -----
+
+// Test case: Testing BufferAssignmentFuncOpConverter and
+// BufferAssignmentReturnOpConverter to see if the return operation matches
+// the new function signature when there are tuple-typed args and results.
+// BufferAssignmentTypeConverter is set to flatten tuple-typed arguments and
+// results. The tuple-typed values should be decomposed and composed using
+// get_tuple_element and make_tuple operations of the test dialect. Tensor types
+// are converted to memref. Memref-typed results are appended to the arguments list.
+
+// CHECK-LABEL: func @decompose_tuple_typed_function_args_and_results
+func @decompose_tuple_typed_function_args_and_results(%arg0: tuple<i1,f32>, %arg1: tensor<10xf32>, %arg2: tuple<i1, tensor<5xf32>>) -> (tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>){
+ return %arg2, %arg1, %arg0 : tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>
+}
+// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<5xf32>, %[[RESULT1:.*]]: memref<10xf32>
+// CHECK-SAME: (i1, i1, f32)
+// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]])
+// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]])
+// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32}
+// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32}
+// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32}
+// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32}
+// CHECK-NEXT: linalg.copy(%[[SECOND_TUPLE_SECOND_ELEM]], %[[RESULT0]])
+// CHECK-NEXT: linalg.copy(%[[ARG2]], %[[RESULT1]])
+// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]]
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index bc26a8659831..f03c953396a4 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1669,7 +1669,7 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
let results = (outs AnyType:$result);
let extraClassDeclaration = [{
- static LogicalResult inferReturnTypes(MLIRContext *,
+ static LogicalResult inferReturnTypes(MLIRContext *,
Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
@@ -1679,4 +1679,31 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
}];
}
+//===----------------------------------------------------------------------===//
+// Test BufferPlacement
+//===----------------------------------------------------------------------===//
+
+def GetTupleElementOp: TEST_Op<"get_tuple_element"> {
+ let description = [{
+ Test op that returns a specified element of the tuple.
+ }];
+
+ let arguments = (ins
+ TupleOf<[AnyType]>,
+ I32Attr:$index
+ );
+ let results = (outs AnyType);
+}
+
+def MakeTupleOp: TEST_Op<"make_tuple"> {
+ let description = [{
+ Test op that creates a tuple value from a list of values.
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs
+ );
+ let results = (outs TupleOf<[AnyType]>);
+}
+
#endif // TEST_OPS
diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
index 6cc0924191cb..14b72b9fc92a 100644
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -11,6 +11,8 @@
//
//===----------------------------------------------------------------------===//
+#include "TestDialect.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
@@ -109,14 +111,16 @@ struct TestBufferPlacementPreparationPass
void populateTensorLinalgToBufferLinalgConversionPattern(
MLIRContext *context, BufferAssignmentPlacer *placer,
- TypeConverter *converter, OwningRewritePatternList *patterns) {
+ BufferAssignmentTypeConverter *converter,
+ OwningRewritePatternList *patterns) {
populateWithBufferAssignmentOpConversionPatterns<
- mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp,
- allowMemrefFunctionResults>(context, placer, converter, patterns);
+ mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer,
+ converter, patterns);
patterns->insert<GenericOpConverter>(context, placer, converter);
}
void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<TestDialect>();
registry.insert<linalg::LinalgDialect>();
}
@@ -127,6 +131,8 @@ struct TestBufferPlacementPreparationPass
// Mark all Standard operations legal.
target.addLegalDialect<StandardOpsDialect>();
+ target.addLegalOp<MakeTupleOp>();
+ target.addLegalOp<GetTupleElementOp>();
// Mark all Linalg operations illegal as long as they work on tensors.
auto isLegalOperation = [&](Operation *op) {
@@ -149,6 +155,42 @@ struct TestBufferPlacementPreparationPass
converter.isLegal(&funcOp.getBody());
});
+ auto kind = allowMemrefFunctionResults
+ ? BufferAssignmentTypeConverter::KeepAsFunctionResult
+ : BufferAssignmentTypeConverter::AppendToArgumentsList;
+ converter.setResultConversionKind<RankedTensorType, MemRefType>(kind);
+ converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>(
+ kind);
+
+ converter.addDecomposeTypeConversion(
+ [](TupleType tupleType, SmallVectorImpl<Type> &types) {
+ tupleType.getFlattenedTypes(types);
+ return success();
+ });
+
+ converter.addArgumentMaterialization(
+ [](OpBuilder &builder, TupleType resultType, ValueRange inputs,
+ Location loc) -> Optional<Value> {
+ if (inputs.size() == 1)
+ return llvm::None;
+ TypeRange TypeRange = inputs.getTypes();
+ SmallVector<Type, 2> types(TypeRange.begin(), TypeRange.end());
+ TupleType tuple = TupleType::get(types, builder.getContext());
+ mlir::Value value = builder.create<MakeTupleOp>(loc, tuple, inputs);
+ return value;
+ });
+
+ converter.addDecomposeValueConversion([](OpBuilder &builder, Location loc,
+ TupleType resultType, Value value,
+ SmallVectorImpl<Value> &values) {
+ for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
+ Value res = builder.create<GetTupleElementOp>(
+ loc, resultType.getType(i), value, builder.getI32IntegerAttr(i));
+ values.push_back(res);
+ }
+ return success();
+ });
+
// Walk over all the functions to apply buffer assignment.
this->getOperation().walk([&](FuncOp function) -> WalkResult {
OwningRewritePatternList patterns;
More information about the Mlir-commits
mailing list