[Mlir-commits] [mlir] 1b88bbf - Revert "[mlir] Extend BufferAssignmentTypeConverter with result conversion callbacks"
Lei Zhang
llvmlistbot at llvm.org
Wed Sep 2 06:26:10 PDT 2020
Author: Lei Zhang
Date: 2020-09-02T09:24:36-04:00
New Revision: 1b88bbf5eb80b38a4dee129df969d5632993fdd1
URL: https://github.com/llvm/llvm-project/commit/1b88bbf5eb80b38a4dee129df969d5632993fdd1
DIFF: https://github.com/llvm/llvm-project/commit/1b88bbf5eb80b38a4dee129df969d5632993fdd1.diff
LOG: Revert "[mlir] Extend BufferAssignmentTypeConverter with result conversion callbacks"
This reverts commit 94f5d248772ba0f1f9c8b0746fe75a5d246c5540 because
it causes the following tests to fail:
MLIR :: Dialect/Linalg/tensors-to-buffers.mlir
MLIR :: Transforms/buffer-placement-preparation-allowed-memref-results.mlir
MLIR :: Transforms/buffer-placement-preparation.mlir
Added:
Modified:
mlir/include/mlir/Transforms/BufferPlacement.h
mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
mlir/lib/Transforms/BufferPlacement.cpp
mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
mlir/test/Transforms/buffer-placement-preparation.mlir
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Transforms/TestBufferPlacement.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h
index 8fc254e6be1e..f8559a9dd939 100644
--- a/mlir/include/mlir/Transforms/BufferPlacement.h
+++ b/mlir/include/mlir/Transforms/BufferPlacement.h
@@ -52,111 +52,6 @@ class BufferAssignmentPlacer {
Operation *operation;
};
-/// A helper type converter class for using inside Buffer Assignment operation
-/// conversion patterns. The default constructor keeps all the types intact
-/// except for the ranked-tensor types which is converted to memref types.
-class BufferAssignmentTypeConverter : public TypeConverter {
-public:
- /// This enum is for showing how buffer placement operation converters should
- /// conduct with certain result type after type conversion. This value can be
- /// set/get for each specific type using setResultConversionKind or
- /// getResultConversionKind.
- enum ResultConversionKind { AppendToArgumentsList, KeepAsFunctionResult };
-
- BufferAssignmentTypeConverter();
-
- /// This method tries to decompose a value of a certain type using provided
- /// decompose callback functions. If it is unable to do so, the original value
- /// is returned.
- void tryDecomposeValue(OpBuilder &, Location, Type, Value,
- SmallVectorImpl<Value> &);
-
- /// This method tries to decompose a type using provided decompose callback
- /// functions. If it is unable to do so, the original type is returned.
- void tryDecomposeType(Type, SmallVectorImpl<Type> &);
-
- /// This method registers a callback function that will be called to decompose
- /// a value of a certain type into several values.
- template <typename FnT,
- typename T = typename llvm::function_traits<FnT>::template arg_t<2>>
- void addDecomposeValueConversion(FnT &&callback) {
- decomposeValueConversions.emplace_back(
- wrapDecomposeValueConversionCallback<T>(std::forward<FnT>(callback)));
- }
-
- /// This method registers a callback function that will be called to decompose
- /// a type into several types.
- template <typename FnT,
- typename T = typename llvm::function_traits<FnT>::template arg_t<0>>
- void addDecomposeTypeConversion(FnT &&callback) {
- auto wrapper =
- wrapDecomposeTypeConversionCallback<T>(std::forward<FnT>(callback));
- decomposeTypeConversions.emplace_back(wrapper);
- addConversion(std::forward<FnT>(callback));
- }
-
- /// This method returns ResultConversionKind for the mapping from `origin`
- /// type to `input` type.
- ResultConversionKind getResultConversionKind(Type origin, Type input);
-
- /// This method registers ResultConversionKind for the mapping from type 'T'
- /// to type 'U'.
- template <typename T, typename U>
- void setResultConversionKind(ResultConversionKind kind) {
- assert((kind != AppendToArgumentsList ||
- llvm::is_one_of<U, MemRefType, UnrankedMemRefType>::value) &&
- "Only the memref typed values can be set to be appended to the "
- "function argument list at the moment");
- resultTypeConversions.emplace_back(
- [&](Type origin, Type input) -> Optional<ResultConversionKind> {
- if (origin.template isa<T>() && input.template isa<U>())
- return kind;
- return llvm::None;
- });
- }
-
-private:
- using DecomposeValueConversionCallFn = std::function<Optional<LogicalResult>(
- OpBuilder &, Location, Type, Value, SmallVectorImpl<Value> &)>;
-
- using DecomposeTypeConversionCallFn =
- std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
-
- using ResultConversionKindFn =
- std::function<Optional<ResultConversionKind>(Type, Type)>;
-
- /// Generate a wrapper for the given decompose value conversion callback.
- template <typename T, typename FnT>
- DecomposeValueConversionCallFn
- wrapDecomposeValueConversionCallback(FnT &&callback) {
- return [callback = std::forward<FnT>(callback)](
- OpBuilder &builder, Location loc, Type type, Value value,
- SmallVectorImpl<Value> &newValues) -> Optional<LogicalResult> {
- if (T derivedType = type.dyn_cast<T>())
- return callback(builder, loc, derivedType, value, newValues);
- return llvm::None;
- };
- }
-
- /// Generate a wrapper for the given decompose type conversion callback.
- template <typename T, typename FnT>
- DecomposeTypeConversionCallFn
- wrapDecomposeTypeConversionCallback(FnT &&callback) {
- return [callback = std::forward<FnT>(callback)](
- Type type,
- SmallVectorImpl<Type> &results) -> Optional<LogicalResult> {
- T derivedType = type.dyn_cast<T>();
- if (!derivedType)
- return llvm::None;
- return callback(derivedType, results);
- };
- }
-
- SmallVector<ResultConversionKindFn, 2> resultTypeConversions;
- SmallVector<DecomposeValueConversionCallFn, 2> decomposeValueConversions;
- SmallVector<DecomposeTypeConversionCallFn, 2> decomposeTypeConversions;
-};
-
/// Helper conversion pattern that encapsulates a BufferAssignmentPlacer
/// instance. Sample usage:
/// class CustomConversionPattern : public
@@ -173,22 +68,43 @@ class BufferAssignmentOpConversionPattern
public:
explicit BufferAssignmentOpConversionPattern(
MLIRContext *context, BufferAssignmentPlacer *bufferAssignment = nullptr,
- BufferAssignmentTypeConverter *converter = nullptr,
- PatternBenefit benefit = 1)
+ TypeConverter *converter = nullptr, PatternBenefit benefit = 1)
: OpConversionPattern<SourceOp>(context, benefit),
- bufferAssignment(bufferAssignment), converter(converter) {
- assert(converter && "The type converter has not been defined");
- }
+ bufferAssignment(bufferAssignment), converter(converter) {}
protected:
BufferAssignmentPlacer *bufferAssignment;
- BufferAssignmentTypeConverter *converter;
+ TypeConverter *converter;
+};
+
+/// A helper type converter class for using inside Buffer Assignment operation
+/// conversion patterns. The default constructor keeps all the types intact
+/// except for the ranked-tensor types which is converted to memref types.
+class BufferAssignmentTypeConverter : public TypeConverter {
+public:
+ BufferAssignmentTypeConverter();
+
+ /// A helper function to check if `type` has been converted from non-memref
+ /// type to memref.
+ static bool isConvertedMemref(Type type, Type before);
};
-/// Converts the signature of the function using BufferAssignmentTypeConverter.
-/// Each result type of the function is kept as a function result or appended to
-/// the function arguments list based on ResultConversionKind for the converted
-/// result type.
+namespace detail {
+
+/// Converts the signature of the function based on whether the function is
+/// allowed to return memref typed results or not using
+/// `allowMemrefFunctionResults` parameter. If this option is false, then it
+/// adds an extra function argument as an output buffer for each function result
+/// which is going to be a memref type only after type conversion. The
+/// other function result types remain unchanged. If
+/// `allowMemrefFunctionResults` is true, the types are converted in place.
+/// Any changes in function signature need to be applied
+/// to return and caller operations. `BufferAssignmentReturnOpConverter` and
+/// `BufferAssignmentCallOpConverter` are two helper function that match the
+/// return and caller operation with the new function signature. Furthermore,
+/// `BufferAssignmentTypeConverter` is a helper `TypeConverter` for converting
+/// tensor typed values to memref typed ones.
+template <bool allowMemrefFunctionResults>
class BufferAssignmentFuncOpConverter
: public BufferAssignmentOpConversionPattern<FuncOp> {
public:
@@ -196,16 +112,58 @@ class BufferAssignmentFuncOpConverter
FuncOp>::BufferAssignmentOpConversionPattern;
/// Performs the actual signature rewriting step.
- LogicalResult matchAndRewrite(mlir::FuncOp, ArrayRef<Value>,
- ConversionPatternRewriter &) const;
+ LogicalResult
+ matchAndRewrite(mlir::FuncOp funcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ if (!converter)
+ return funcOp.emitError("The type converter has not been defined for "
+ "BufferAssignmentFuncOpConverter");
+ auto funcType = funcOp.getType();
+
+ // Convert function arguments using the provided TypeConverter.
+ TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
+ for (auto argType : llvm::enumerate(funcType.getInputs()))
+ conversion.addInputs(argType.index(),
+ converter->convertType(argType.value()));
+
+ // If allowMemrefFunctionResults is false and a function result type is not
+ // a memref but it would be a memref after type conversion, a new argument
+ // should be appended to the function arguments list for this result.
+ // Otherwise, it remains unchanged as a function result.
+ SmallVector<Type, 2> newResultTypes;
+ newResultTypes.reserve(funcOp.getNumResults());
+ for (Type resType : funcType.getResults()) {
+ Type convertedType = converter->convertType(resType);
+ if (!allowMemrefFunctionResults &&
+ BufferAssignmentTypeConverter::isConvertedMemref(convertedType,
+ resType))
+ conversion.addInputs(convertedType);
+ else
+ newResultTypes.push_back(convertedType);
+ }
+ if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter,
+ &conversion)))
+ return failure();
+
+ // Update the signature of the function.
+ rewriter.updateRootInPlace(funcOp, [&] {
+ funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
+ newResultTypes));
+ });
+ return success();
+ }
};
/// Rewrites the `ReturnOp` to conform with the changed function signature.
-/// Operands that correspond to return values and their types have been set to
-/// AppendToArgumentsList are dropped. In their place, a corresponding copy
-/// operation from the operand to the target function argument is inserted.
+/// if allowMemrefFunctionResults is false, operands that correspond to return
+/// values and have been rewritten from illegal typed results to memref
+/// arguments are dropped. In their place, a corresponding copy operation from
+/// the operand to the output function argument is inserted. Otherwise, the
+/// memref typed operands are returned.
+/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter,
+/// allowMemrefFunctionResults must be set/unset for both.
template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
- typename CopyOpTy>
+ typename CopyOpTy, bool allowMemrefFunctionResults>
class BufferAssignmentReturnOpConverter
: public BufferAssignmentOpConversionPattern<ReturnOpSourceTy> {
public:
@@ -216,48 +174,44 @@ class BufferAssignmentReturnOpConverter
LogicalResult
matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- Location loc = returnOp.getLoc();
-
- // Split the operands depending on whether they need a copy operation or
- // they remain as operands of the return operation. If an operand is
- // decomposable and a decompose callback function has been provided by the
- // user, it will be unpacked.
- SmallVector<Value, 2> newOperands, needCopyOperands;
- OpBuilder builder(returnOp);
- for (auto operand : llvm::enumerate(operands)) {
- SmallVector<Value, 2> values;
- this->converter->tryDecomposeValue(
- builder, loc, operand.value().getType(), operand.value(), values);
- Type type = returnOp.getOperand(operand.index()).getType();
- SmallVector<Type, 2> originTypes;
- this->converter->tryDecomposeType(type, originTypes);
- for (auto value : llvm::enumerate(values)) {
- Type origin = originTypes[value.index()];
- Type converted = value.value().getType();
- auto kind = this->converter->getResultConversionKind(origin, converted);
- if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult)
- newOperands.push_back(value.value());
- else
- // kind = BufferAssignmentTypeConverter::AppendToArgumentsList
- needCopyOperands.push_back(value.value());
- }
+ // If the memref typed results can be returned as function results, the new
+ // `ReturnOp` should only return the type converted operands.
+ if (allowMemrefFunctionResults) {
+ rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, operands);
+ return success();
}
- // Insert Copy operations instead for the operands that have been removed
- // from operand list and appended to the function arguments list.
+ // Split the operands by their kinds whether they are converted memref or
+ // not.
+ SmallVector<Value, 2> needCopyOperands, newOperands;
+ unsigned operandsSize = operands.size();
+ needCopyOperands.reserve(operandsSize);
+ newOperands.reserve(operandsSize);
+ for (auto operand : llvm::enumerate(operands))
+ if (BufferAssignmentTypeConverter::isConvertedMemref(
+ operand.value().getType(),
+ returnOp.getOperand(operand.index()).getType()))
+ needCopyOperands.push_back(operand.value());
+ else
+ newOperands.push_back(operand.value());
+
Block &entryBlock = returnOp.getParentRegion()->front();
unsigned numFuncArgs = entryBlock.getNumArguments();
- if (needCopyOperands.size() > numFuncArgs)
- return returnOp.emitError(
- "The number of operands that need Copy operations is more "
- "than the number of target function arguments.");
+
+ // Find the index of the first destination buffer.
+ assert(needCopyOperands.size() <= numFuncArgs &&
+ "The number of operands of return operation is more than the "
+ "number of function arguments.");
unsigned destArgNum = numFuncArgs - needCopyOperands.size();
rewriter.setInsertionPoint(returnOp);
for (Value operand : needCopyOperands) {
- rewriter.create<CopyOpTy>(loc, operand,
+ // Insert a `CopyOp` for each converted memref-type operand.
+ rewriter.create<CopyOpTy>(returnOp.getLoc(), operand,
entryBlock.getArgument(destArgNum));
++destArgNum;
}
+
+ // Insert the new target Return operation.
rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, newOperands);
return success();
}
@@ -265,32 +219,94 @@ class BufferAssignmentReturnOpConverter
/// Rewrites the `CallOp` to match its operands and results with the signature
/// of the callee after rewriting the callee with
-/// BufferAssignmentFuncOpConverter.
+/// BufferAssignmentFuncOpConverter. If allowMemrefFunctionResults is false, a
+/// buffer is allocated as an output buffer only for each memref typed result
+/// that has been rewritten. The new allocated buffer is passed through the
+/// operands list of the new `CallOp`.
+/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter,
+/// allowMemrefFunctionResults must be set/unset for both.
+template <bool allowMemrefFunctionResults>
class BufferAssignmentCallOpConverter
: public BufferAssignmentOpConversionPattern<CallOp> {
public:
using BufferAssignmentOpConversionPattern<
CallOp>::BufferAssignmentOpConversionPattern;
- /// Performs the actual rewriting step.
- LogicalResult matchAndRewrite(CallOp, ArrayRef<Value>,
- ConversionPatternRewriter &) const;
+ LogicalResult
+ matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ if (!converter)
+ return callOp.emitError("The type converter has not been defined for "
+ "BufferAssignmentCallOpConverter");
+ Location loc = callOp.getLoc();
+
+ // If the memref typed results can be returned as function results, there is
+ // no need to create output buffers. It is only required to convert the type
+ // of operands and results in place for creating the new `CallOp`.
+ if (allowMemrefFunctionResults) {
+ SmallVector<Type, 2> resultTypes;
+ resultTypes.reserve(callOp.getNumResults());
+ for (Type type : callOp.getResultTypes())
+ resultTypes.push_back(converter->convertType(type));
+ rewriter.replaceOpWithNewOp<CallOp>(callOp, callOp.getCallee(),
+ resultTypes, operands);
+ return success();
+ }
+
+ SmallVector<Value, 2> newOperands, replacingValues;
+ SmallVector<Type, 2> newResultTypes;
+ unsigned numResults = callOp.getNumResults();
+ newOperands.reserve(numResults + operands.size());
+ newOperands.append(operands.begin(), operands.end());
+ newResultTypes.reserve(numResults);
+ replacingValues.reserve(numResults);
+
+ // For each memref result of `CallOp` which has not been a memref before
+ // the type conversion, a new buffer is allocated and passed to the operands
+ // list of the new `CallOp`. Otherwise, it remains as a caller result.
+ for (Value result : callOp.getResults()) {
+ Type currType = result.getType();
+ Type newType = converter->convertType(result.getType());
+ if (BufferAssignmentTypeConverter::isConvertedMemref(newType, currType)) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.restoreInsertionPoint(bufferAssignment->computeAllocPosition(
+ result.dyn_cast<OpResult>()));
+ Value alloc =
+ rewriter.create<AllocOp>(loc, newType.dyn_cast<MemRefType>());
+ newOperands.push_back(alloc);
+ replacingValues.push_back(alloc);
+ } else {
+ newResultTypes.push_back(currType);
+
+ // No replacing is required.
+ replacingValues.push_back(nullptr);
+ }
+ }
+
+ // Creating the new `CallOp`.
+ rewriter.create<CallOp>(loc, callOp.getCallee(), newResultTypes,
+ newOperands);
+
+ // Replacing the results of the old `CallOp`.
+ rewriter.replaceOp(callOp, replacingValues);
+ return success();
+ }
};
+} // end namespace detail
/// Populates `patterns` with the conversion patterns of buffer
/// assignment.
template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
- typename CopyOpTy>
+ typename CopyOpTy, bool allowMemrefFunctionResults>
static void populateWithBufferAssignmentOpConversionPatterns(
MLIRContext *context, BufferAssignmentPlacer *placer,
- BufferAssignmentTypeConverter *converter,
- OwningRewritePatternList *patterns) {
+ TypeConverter *converter, OwningRewritePatternList *patterns) {
// clang-format off
patterns->insert<
- BufferAssignmentCallOpConverter,
- BufferAssignmentFuncOpConverter,
- BufferAssignmentReturnOpConverter
- <ReturnOpSourceTy, ReturnOpTargetTy, CopyOpTy>
+ detail::BufferAssignmentCallOpConverter<allowMemrefFunctionResults>,
+ detail::BufferAssignmentFuncOpConverter<allowMemrefFunctionResults>,
+ detail::BufferAssignmentReturnOpConverter
+ <ReturnOpSourceTy, ReturnOpTargetTy, CopyOpTy, allowMemrefFunctionResults>
>(context, placer, converter);
// clang-format on
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
index 89a01f9ca629..04c1fbd5d565 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -100,11 +100,11 @@ class GenericOpConverter
/// tensors to buffers.
static void populateConvertLinalgOnTensorsToBuffersPattern(
MLIRContext *context, BufferAssignmentPlacer *placer,
- BufferAssignmentTypeConverter *converter,
- OwningRewritePatternList *patterns) {
+ TypeConverter *converter, OwningRewritePatternList *patterns) {
populateWithBufferAssignmentOpConversionPatterns<
- mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer,
- converter, patterns);
+ mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp,
+ /*allowMemrefFunctionResults=*/false>(context, placer, converter,
+ patterns);
patterns->insert<GenericOpConverter>(context, placer, converter);
}
@@ -141,9 +141,6 @@ struct ConvertLinalgOnTensorsToBuffers
converter.isLegal(&funcOp.getBody());
});
- converter.setResultConversionKind<RankedTensorType, MemRefType>(
- BufferAssignmentTypeConverter::AppendToArgumentsList);
-
// Walk over all the functions to apply buffer assignment.
getOperation().walk([&](FuncOp function) -> WalkResult {
OwningRewritePatternList patterns;
diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
index 1ab3e7e2e48d..201570a244ff 100644
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -713,223 +713,9 @@ BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() {
});
}
-/// This method tries to decompose a value of a certain type using provided
-/// decompose callback functions. If it is unable to do so, the original value
-/// is returned.
-void BufferAssignmentTypeConverter::tryDecomposeValue(
- OpBuilder &builder, Location loc, Type type, Value value,
- SmallVectorImpl<Value> &results) {
- for (auto conversion : decomposeValueConversions)
- if (conversion(builder, loc, type, value, results) != llvm::None)
- return;
- results.push_back(value);
-}
-
-/// This method tries to decompose a type using provided decompose callback
-/// functions. If it is unable to do so, the original type is returned.
-void BufferAssignmentTypeConverter::tryDecomposeType(
- Type type, SmallVectorImpl<Type> &types) {
- for (auto conversion : decomposeTypeConversions)
- if (conversion(type, types) != llvm::None)
- return;
- types.push_back(type);
-}
-
-/// This method returns ResultConversionKind for the input type.
-BufferAssignmentTypeConverter::ResultConversionKind
-BufferAssignmentTypeConverter::getResultConversionKind(Type origin,
- Type converted) {
- for (auto conversion : resultTypeConversions) {
- auto res = conversion(origin, converted);
- if (res != llvm::None)
- return res.getValue();
- }
- return KeepAsFunctionResult;
-}
-
-//===----------------------------------------------------------------------===//
-// BufferAssignmentFuncOpConverter
-//===----------------------------------------------------------------------===//
-
-/// Performs the actual function signature rewriting step.
-LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite(
- mlir::FuncOp funcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- auto funcType = funcOp.getType();
-
- // Convert function arguments using the provided TypeConverter.
- TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
- for (auto argType : llvm::enumerate(funcType.getInputs())) {
- SmallVector<Type, 2> decomposedTypes, convertedTypes;
- converter->tryDecomposeType(argType.value(), decomposedTypes);
- converter->convertTypes(decomposedTypes, convertedTypes);
- conversion.addInputs(argType.index(), convertedTypes);
- }
-
- // Convert the result types of the function.
- SmallVector<Type, 2> newResultTypes;
- newResultTypes.reserve(funcOp.getNumResults());
- for (Type resultType : funcType.getResults()) {
- SmallVector<Type, 2> originTypes;
- converter->tryDecomposeType(resultType, originTypes);
- for (auto origin : originTypes) {
- Type converted = converter->convertType(origin);
- auto kind = converter->getResultConversionKind(origin, converted);
- if (kind == BufferAssignmentTypeConverter::AppendToArgumentsList)
- conversion.addInputs(converted);
- else
- // kind = BufferAssignmentTypeConverter::KeepAsFunctionResult
- newResultTypes.push_back(converted);
- }
- }
-
- if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter,
- &conversion)))
- return failure();
-
- // Update the signature of the function.
- rewriter.updateRootInPlace(funcOp, [&] {
- funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
- newResultTypes));
- });
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// BufferAssignmentCallOpConverter
-//===----------------------------------------------------------------------===//
-
-/// Performs the actual rewriting step.
-LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
- CallOp callOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
-
- // This class represents a mapping from a result to a list of values and some
- // results that have not yet constructed. Instead, the indices of these
- // results in the operation that will be constructed are known. They will be
- // replaced with the actual values when they are available. The order of
- // adding to this mapping is important.
- class ResultMapping {
- public:
- ResultMapping() { order = 0; };
-
- /// Add an available value to the mapping.
- void addMapping(Value value) {
- toValuesMapping.push_back({order++, value});
- }
-
- /// Add the index of unavailble result value to the mapping.
- void addMapping(unsigned index) {
- toIndicesMapping.push_back({order++, index});
- }
-
- /// This method returns the mapping values list. The unknown result values
- /// that only their indicies are available are replaced with their values.
- void getMappingValues(ValueRange valuesToReplaceIndices,
- SmallVectorImpl<Value> &values) {
- // Append available values to the list.
- SmallVector<std::pair<unsigned, Value>, 2> res(toValuesMapping.begin(),
- toValuesMapping.end());
- // Replace the indices with the actual values.
- llvm::for_each(
- toIndicesMapping, [&](const std::pair<unsigned, unsigned> &entry) {
- assert(entry.second < valuesToReplaceIndices.size() &&
- "The value index is out of range.");
- res.push_back({entry.first, valuesToReplaceIndices[entry.second]});
- });
- // Sort the values based on their adding orders.
- llvm::sort(res, [](const std::pair<unsigned, Value> &v1,
- const std::pair<unsigned, Value> &v2) {
- return v1.first < v2.first;
- });
- // Fill the values.
- llvm::for_each(res, [&](const std::pair<unsigned, Value> &entry) {
- values.push_back(entry.second);
- });
- }
-
- private:
- /// Keeping the inserting order of mapping values.
- int order;
-
- /// Containing the mapping values with their inserting orders.
- SmallVector<std::pair<unsigned, Value>, 2> toValuesMapping;
-
- /// Containing the indices of result values with their inserting orders.
- SmallVector<std::pair<unsigned, unsigned>, 2> toIndicesMapping;
- };
-
- Location loc = callOp.getLoc();
- OpBuilder builder(callOp);
- SmallVector<Value, 2> newOperands;
-
- // Create the operands list of the new `CallOp`. It unpacks the decomposable
- // values if a decompose callback function has been provided by the user.
- for (auto operand : operands) {
- SmallVector<Value, 2> values;
- this->converter->tryDecomposeValue(builder, loc, operand.getType(), operand,
- values);
- newOperands.append(values.begin(), values.end());
- }
-
- // Create the new result types for the new `CallOp` and a mapping from the old
- // result to new value(s).
- SmallVector<Type, 2> newResultTypes;
- SmallVector<ResultMapping, 4> mappings;
- mappings.resize(callOp.getNumResults());
- for (auto result : llvm::enumerate(callOp.getResults())) {
- SmallVector<Type, 2> originTypes;
- converter->tryDecomposeType(result.value().getType(), originTypes);
- auto &resultMapping = mappings[result.index()];
- for (Type origin : originTypes) {
- Type converted = converter->convertType(origin);
- auto kind = converter->getResultConversionKind(origin, converted);
- if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) {
- newResultTypes.push_back(converted);
- // The result value is not yet available. Its index is kept and it is
- // replaced with the actual value of the new `CallOp` later.
- resultMapping.addMapping(newResultTypes.size() - 1);
- } else {
- // kind = BufferAssignmentTypeConverter::AppendToArgumentsList
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.restoreInsertionPoint(
- bufferAssignment->computeAllocPosition(result.value()));
- MemRefType memref = converted.dyn_cast<MemRefType>();
- if (!memref)
- return callOp.emitError("Cannot allocate for a non-Memref type");
- Value alloc = rewriter.create<AllocOp>(loc, memref);
- newOperands.push_back(alloc);
- resultMapping.addMapping(alloc);
- }
- }
- }
-
- CallOp newCallOp = rewriter.create<CallOp>(loc, callOp.getCallee(),
- newResultTypes, newOperands);
-
- // Build a replacing value for each result to replace its uses. If a result
- // has multiple mapping values, it needs to be packed to a single value.
- OpBuilder nextBuilder(callOp.getOperation()->getNextNode());
- SmallVector<Value, 2> replacedValues;
- replacedValues.reserve(callOp.getNumResults());
- for (unsigned i = 0, e = callOp.getNumResults(); i < e; ++i) {
- SmallVector<Value, 2> valuesToPack;
- mappings[i].getMappingValues(newCallOp.getResults(), valuesToPack);
- if (valuesToPack.empty()) {
- // No replacement is required.
- replacedValues.push_back(nullptr);
- } else if (valuesToPack.size() == 1) {
- replacedValues.push_back(valuesToPack.front());
- } else {
- // Values need to be packed using callback function. The same callback
- // that is used for materializeArgumentConversion is used for packing.
- Value packed = converter->materializeArgumentConversion(
- nextBuilder, loc, callOp.getType(i), valuesToPack);
- replacedValues.push_back(packed);
- }
- }
- rewriter.replaceOp(callOp, replacedValues);
- return success();
+/// Checks if `type` has been converted from non-memref type to memref.
+bool BufferAssignmentTypeConverter::isConvertedMemref(Type type, Type before) {
+ return type.isa<BaseMemRefType>() && !before.isa<BaseMemRefType>();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
index e1dacdf0184e..084ac38af6e3 100644
--- a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
+++ b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
@@ -111,73 +111,7 @@ func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
// CHECK: %[[Y:.*]]:2 = call @callee(%[[X]]#0)
// CHECK: return %[[Y]]#0
-// -----
-
-// Test case: Testing BufferAssginmnetCallOpConverter to see if it matches with the
-// signature of the new signature of the callee function when there are tuple typed
-// args and results. BufferAssginmentTypeConverter is set to flatten tuple typed
-// arguments. The tuple typed values should be decomposed and composed using
-// get_tuple_element and make_tuple operations of test dialect. Tensor types are
-// converted to Memref. Memref typed function results remain as function results.
-// CHECK-LABEL: func @callee
-func @callee(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>){
- return %arg0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
-}
-// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>)
-// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>)
-// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]
-// CHECK-LABEL: func @caller
-func @caller(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> tuple<tensor<2xf32>,i1, tensor<5xf32>>{
- %x0 = call @callee(%arg0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
- %y0 = call @callee(%x0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
- return %y0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
-}
-// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>)
-// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>)
-// CHECK-NEXT: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]])
-// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>)
-// CHECK-NEXT: %[[RESULT_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2)
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]])
-// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>)
-// CHECK-NEXT: %[[RETURN_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2)
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]
-// -----
-// Test case: Testing BufferAssginmnetFuncOpConverter and
-// BufferAssginmentReturnOpConverter to see if the return operation matches with
-// the new function signature when there are tuple typed args and results.
-// BufferAssginmentTypeConverter is set to flatten tuple typed arguments. The tuple
-// typed values should be decomposed and composed using get_tuple_element and
-// make_tuple operations of test dialect. Tensor types are converted to Memref.
-// Memref typed function results remain as function results.
-
-// CHECK-LABEL: func @decompose_tuple_typed_function_args_and_results
-func @decompose_tuple_typed_function_args_and_results(%arg0: tuple<i1,f32>, %arg1: tensor<10xf32>, %arg2: tuple<i1, tensor<5xf32>>) -> (tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>){
- return %arg2, %arg1, %arg0 : tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>
-}
-// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32>
-// CHECK-SAME: (i1, memref<5xf32>, memref<10xf32>, i1, f32)
-// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]])
-// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]])
-// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[SECOND_TUPLE_SECOND_ELEM]], %[[ARG2]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]]
diff --git a/mlir/test/Transforms/buffer-placement-preparation.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir
index b1cfdfd690cf..064b0fd7e85a 100644
--- a/mlir/test/Transforms/buffer-placement-preparation.mlir
+++ b/mlir/test/Transforms/buffer-placement-preparation.mlir
@@ -285,93 +285,8 @@ func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
// CHECK: linalg.copy(%[[Y0]], %[[CALLER_RESULT]])
// CHECK: return
-// -----
-
// CHECK-LABEL: func @func_with_unranked_arg
func @func_with_unranked_arg(%arg0: tensor<*xf32>) {
return
}
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)
-
-// -----
-
-// Test case: Testing BufferAssginmnetCallOpConverter to see if it matches with the
-// signature of the new signature of the callee function when there are tuple typed
-// args and results. BufferAssginmentTypeConverter is set to flatten tuple typed
-// arguments. The tuple typed values should be decomposed and composed using
-// get_tuple_element and make_tuple operations of test dialect. Tensor types are
-// converted to Memref. Memref typed function results are appended to the function
-// arguments list.
-
-// CHECK-LABEL: func @callee
-func @callee(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>){
- return %arg0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
-}
-// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<2xf32>, %[[RESULT1:.*]]: memref<5xf32>)
-// CHECK-SAME: i1
-// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: linalg.copy(%[[FIRST_ELEM]], %[[RESULT0]])
-// CHECK-NEXT: linalg.copy(%[[THIRD_ELEM]], %[[RESULT1]])
-// CHECK-NEXT: return %[[SECOND_ELEM]]
-
-
-// CHECK-LABEL: func @caller
-func @caller(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> tuple<tensor<2xf32>,i1, tensor<5xf32>>{
- %x0 = call @callee(%arg0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
- %y0 = call @callee(%x0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
- return %y0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
-}
-// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<2xf32>, %[[RESULT1:.*]]: memref<5xf32>)
-// CHECK-SAME: i1
-// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc()
-// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc()
-// CHECK-NEXT: %[[CALLEE_RESULT:.*]] = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]])
-// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>, memref<2xf32>, memref<5xf32>) -> i1
-// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[FIRST_ALLOC]], %[[CALLEE_RESULT]], %[[SECOND_ALLOC]])
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc()
-// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc()
-// CHECK-NEXT: %[[CALLEE_RESULT:.*]] = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]])
-// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>, memref<2xf32>, memref<5xf32>) -> i1
-// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[FIRST_ALLOC]], %[[CALLEE_RESULT]], %[[SECOND_ALLOC]])
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: linalg.copy(%[[FIRST_ELEM]], %[[RESULT0]])
-// CHECK-NEXT: linalg.copy(%[[THIRD_ELEM]], %[[RESULT1]])
-// CHECK-NEXT: return %[[SECOND_ELEM]]
-
-// -----
-
-// Test case: Testing BufferAssginmnetFuncOpConverter and
-// BufferAssginmentReturnOpConverter to see if the return operation matches with
-// the new function signature when there are tuple typed args and results.
-// BufferAssginmentTypeConverter is set to flatten tuple typed arguments. The tuple
-// typed values should be decomposed and composed using get_tuple_element and
-// make_tuple operations of test dialect. Tensor types are converted to Memref.
-// Memref typed function results are appended to the function arguments list.
-
-// CHECK-LABEL: func @decompose_tuple_typed_function_args_and_results
-func @decompose_tuple_typed_function_args_and_results(%arg0: tuple<i1,f32>, %arg1: tensor<10xf32>, %arg2: tuple<i1, tensor<5xf32>>) -> (tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>){
- return %arg2, %arg1, %arg0 : tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>
-}
-// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<5xf32>, %[[RESULT1:.*]]: memref<10xf32>
-// CHECK-SAME: (i1, i1, f32)
-// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]])
-// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]])
-// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: linalg.copy(%[[SECOND_TUPLE_SECOND_ELEM]], %[[RESULT0]])
-// CHECK-NEXT: linalg.copy(%[[ARG2]], %[[RESULT1]])
-// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]]
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index f03c953396a4..bc26a8659831 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1669,7 +1669,7 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
let results = (outs AnyType:$result);
let extraClassDeclaration = [{
- static LogicalResult inferReturnTypes(MLIRContext *,
+ static LogicalResult inferReturnTypes(MLIRContext *,
Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
@@ -1679,31 +1679,4 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
}];
}
-//===----------------------------------------------------------------------===//
-// Test BufferPlacement
-//===----------------------------------------------------------------------===//
-
-def GetTupleElementOp: TEST_Op<"get_tuple_element"> {
- let description = [{
- Test op that returns a specified element of the tuple.
- }];
-
- let arguments = (ins
- TupleOf<[AnyType]>,
- I32Attr:$index
- );
- let results = (outs AnyType);
-}
-
-def MakeTupleOp: TEST_Op<"make_tuple"> {
- let description = [{
- Test op that creates a tuple value from a list of values.
- }];
-
- let arguments = (ins
- Variadic<AnyType>:$inputs
- );
- let results = (outs TupleOf<[AnyType]>);
-}
-
#endif // TEST_OPS
diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
index 14b72b9fc92a..6cc0924191cb 100644
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -11,8 +11,6 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
@@ -111,16 +109,14 @@ struct TestBufferPlacementPreparationPass
void populateTensorLinalgToBufferLinalgConversionPattern(
MLIRContext *context, BufferAssignmentPlacer *placer,
- BufferAssignmentTypeConverter *converter,
- OwningRewritePatternList *patterns) {
+ TypeConverter *converter, OwningRewritePatternList *patterns) {
populateWithBufferAssignmentOpConversionPatterns<
- mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer,
- converter, patterns);
+ mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp,
+ allowMemrefFunctionResults>(context, placer, converter, patterns);
patterns->insert<GenericOpConverter>(context, placer, converter);
}
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<TestDialect>();
registry.insert<linalg::LinalgDialect>();
}
@@ -131,8 +127,6 @@ struct TestBufferPlacementPreparationPass
// Mark all Standard operations legal.
target.addLegalDialect<StandardOpsDialect>();
- target.addLegalOp<MakeTupleOp>();
- target.addLegalOp<GetTupleElementOp>();
// Mark all Linalg operations illegal as long as they work on tensors.
auto isLegalOperation = [&](Operation *op) {
@@ -155,42 +149,6 @@ struct TestBufferPlacementPreparationPass
converter.isLegal(&funcOp.getBody());
});
- auto kind = allowMemrefFunctionResults
- ? BufferAssignmentTypeConverter::KeepAsFunctionResult
- : BufferAssignmentTypeConverter::AppendToArgumentsList;
- converter.setResultConversionKind<RankedTensorType, MemRefType>(kind);
- converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>(
- kind);
-
- converter.addDecomposeTypeConversion(
- [](TupleType tupleType, SmallVectorImpl<Type> &types) {
- tupleType.getFlattenedTypes(types);
- return success();
- });
-
- converter.addArgumentMaterialization(
- [](OpBuilder &builder, TupleType resultType, ValueRange inputs,
- Location loc) -> Optional<Value> {
- if (inputs.size() == 1)
- return llvm::None;
- TypeRange TypeRange = inputs.getTypes();
- SmallVector<Type, 2> types(TypeRange.begin(), TypeRange.end());
- TupleType tuple = TupleType::get(types, builder.getContext());
- mlir::Value value = builder.create<MakeTupleOp>(loc, tuple, inputs);
- return value;
- });
-
- converter.addDecomposeValueConversion([](OpBuilder &builder, Location loc,
- TupleType resultType, Value value,
- SmallVectorImpl<Value> &values) {
- for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
- Value res = builder.create<GetTupleElementOp>(
- loc, resultType.getType(i), value, builder.getI32IntegerAttr(i));
- values.push_back(res);
- }
- return success();
- });
-
// Walk over all the functions to apply buffer assignment.
this->getOperation().walk([&](FuncOp function) -> WalkResult {
OwningRewritePatternList patterns;
More information about the Mlir-commits
mailing list