[Mlir-commits] [mlir] 1b88bbf - Revert "[mlir] Extend BufferAssignmentTypeConverter with result conversion callbacks"

Lei Zhang llvmlistbot at llvm.org
Wed Sep 2 06:26:10 PDT 2020


Author: Lei Zhang
Date: 2020-09-02T09:24:36-04:00
New Revision: 1b88bbf5eb80b38a4dee129df969d5632993fdd1

URL: https://github.com/llvm/llvm-project/commit/1b88bbf5eb80b38a4dee129df969d5632993fdd1
DIFF: https://github.com/llvm/llvm-project/commit/1b88bbf5eb80b38a4dee129df969d5632993fdd1.diff

LOG: Revert "[mlir] Extend BufferAssignmentTypeConverter with result conversion callbacks"

This reverts commit 94f5d248772ba0f1f9c8b0746fe75a5d246c5540 because
it causes the following tests to fail:

MLIR :: Dialect/Linalg/tensors-to-buffers.mlir
MLIR :: Transforms/buffer-placement-preparation-allowed-memref-results.mlir
MLIR :: Transforms/buffer-placement-preparation.mlir

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/BufferPlacement.h
    mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
    mlir/lib/Transforms/BufferPlacement.cpp
    mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
    mlir/test/Transforms/buffer-placement-preparation.mlir
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Transforms/TestBufferPlacement.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h
index 8fc254e6be1e..f8559a9dd939 100644
--- a/mlir/include/mlir/Transforms/BufferPlacement.h
+++ b/mlir/include/mlir/Transforms/BufferPlacement.h
@@ -52,111 +52,6 @@ class BufferAssignmentPlacer {
   Operation *operation;
 };
 
-/// A helper type converter class for using inside Buffer Assignment operation
-/// conversion patterns. The default constructor keeps all the types intact
-/// except for the ranked-tensor types which is converted to memref types.
-class BufferAssignmentTypeConverter : public TypeConverter {
-public:
-  /// This enum is for showing how buffer placement operation converters should
-  /// conduct with certain result type after type conversion. This value can be
-  /// set/get for each specific type using setResultConversionKind or
-  /// getResultConversionKind.
-  enum ResultConversionKind { AppendToArgumentsList, KeepAsFunctionResult };
-
-  BufferAssignmentTypeConverter();
-
-  /// This method tries to decompose a value of a certain type using provided
-  /// decompose callback functions. If it is unable to do so, the original value
-  /// is returned.
-  void tryDecomposeValue(OpBuilder &, Location, Type, Value,
-                         SmallVectorImpl<Value> &);
-
-  /// This method tries to decompose a type using provided decompose callback
-  /// functions. If it is unable to do so, the original type is returned.
-  void tryDecomposeType(Type, SmallVectorImpl<Type> &);
-
-  /// This method registers a callback function that will be called to decompose
-  /// a value of a certain type into several values.
-  template <typename FnT,
-            typename T = typename llvm::function_traits<FnT>::template arg_t<2>>
-  void addDecomposeValueConversion(FnT &&callback) {
-    decomposeValueConversions.emplace_back(
-        wrapDecomposeValueConversionCallback<T>(std::forward<FnT>(callback)));
-  }
-
-  /// This method registers a callback function that will be called to decompose
-  /// a type into several types.
-  template <typename FnT,
-            typename T = typename llvm::function_traits<FnT>::template arg_t<0>>
-  void addDecomposeTypeConversion(FnT &&callback) {
-    auto wrapper =
-        wrapDecomposeTypeConversionCallback<T>(std::forward<FnT>(callback));
-    decomposeTypeConversions.emplace_back(wrapper);
-    addConversion(std::forward<FnT>(callback));
-  }
-
-  /// This method returns ResultConversionKind for the mapping from `origin`
-  /// type to `input` type.
-  ResultConversionKind getResultConversionKind(Type origin, Type input);
-
-  /// This method registers ResultConversionKind for the mapping from type 'T'
-  /// to type 'U'.
-  template <typename T, typename U>
-  void setResultConversionKind(ResultConversionKind kind) {
-    assert((kind != AppendToArgumentsList ||
-            llvm::is_one_of<U, MemRefType, UnrankedMemRefType>::value) &&
-           "Only the memref typed values can be set to be appended to the "
-           "function argument list at the moment");
-    resultTypeConversions.emplace_back(
-        [&](Type origin, Type input) -> Optional<ResultConversionKind> {
-          if (origin.template isa<T>() && input.template isa<U>())
-            return kind;
-          return llvm::None;
-        });
-  }
-
-private:
-  using DecomposeValueConversionCallFn = std::function<Optional<LogicalResult>(
-      OpBuilder &, Location, Type, Value, SmallVectorImpl<Value> &)>;
-
-  using DecomposeTypeConversionCallFn =
-      std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
-
-  using ResultConversionKindFn =
-      std::function<Optional<ResultConversionKind>(Type, Type)>;
-
-  /// Generate a wrapper for the given decompose value conversion callback.
-  template <typename T, typename FnT>
-  DecomposeValueConversionCallFn
-  wrapDecomposeValueConversionCallback(FnT &&callback) {
-    return [callback = std::forward<FnT>(callback)](
-               OpBuilder &builder, Location loc, Type type, Value value,
-               SmallVectorImpl<Value> &newValues) -> Optional<LogicalResult> {
-      if (T derivedType = type.dyn_cast<T>())
-        return callback(builder, loc, derivedType, value, newValues);
-      return llvm::None;
-    };
-  }
-
-  /// Generate a wrapper for the given decompose type conversion callback.
-  template <typename T, typename FnT>
-  DecomposeTypeConversionCallFn
-  wrapDecomposeTypeConversionCallback(FnT &&callback) {
-    return [callback = std::forward<FnT>(callback)](
-               Type type,
-               SmallVectorImpl<Type> &results) -> Optional<LogicalResult> {
-      T derivedType = type.dyn_cast<T>();
-      if (!derivedType)
-        return llvm::None;
-      return callback(derivedType, results);
-    };
-  }
-
-  SmallVector<ResultConversionKindFn, 2> resultTypeConversions;
-  SmallVector<DecomposeValueConversionCallFn, 2> decomposeValueConversions;
-  SmallVector<DecomposeTypeConversionCallFn, 2> decomposeTypeConversions;
-};
-
 /// Helper conversion pattern that encapsulates a BufferAssignmentPlacer
 /// instance. Sample usage:
 /// class CustomConversionPattern : public
@@ -173,22 +68,43 @@ class BufferAssignmentOpConversionPattern
 public:
   explicit BufferAssignmentOpConversionPattern(
       MLIRContext *context, BufferAssignmentPlacer *bufferAssignment = nullptr,
-      BufferAssignmentTypeConverter *converter = nullptr,
-      PatternBenefit benefit = 1)
+      TypeConverter *converter = nullptr, PatternBenefit benefit = 1)
       : OpConversionPattern<SourceOp>(context, benefit),
-        bufferAssignment(bufferAssignment), converter(converter) {
-    assert(converter && "The type converter has not been defined");
-  }
+        bufferAssignment(bufferAssignment), converter(converter) {}
 
 protected:
   BufferAssignmentPlacer *bufferAssignment;
-  BufferAssignmentTypeConverter *converter;
+  TypeConverter *converter;
+};
+
+/// A helper type converter class for using inside Buffer Assignment operation
+/// conversion patterns. The default constructor keeps all the types intact
+/// except for the ranked-tensor types which is converted to memref types.
+class BufferAssignmentTypeConverter : public TypeConverter {
+public:
+  BufferAssignmentTypeConverter();
+
+  /// A helper function to check if `type` has been converted from non-memref
+  /// type to memref.
+  static bool isConvertedMemref(Type type, Type before);
 };
 
-/// Converts the signature of the function using BufferAssignmentTypeConverter.
-/// Each result type of the function is kept as a function result or appended to
-/// the function arguments list based on ResultConversionKind for the converted
-/// result type.
+namespace detail {
+
+/// Converts the signature of the function based on whether the function is
+/// allowed to return memref typed results or not using
+/// `allowMemrefFunctionResults` parameter. If this option is false, then it
+/// adds an extra function argument as an output buffer for each function result
+/// which is going to be a memref type only after type conversion. The
+/// other function result types remain unchanged. If
+/// `allowMemrefFunctionResults` is true, the types are converted in place.
+/// Any changes in function signature need to be applied
+/// to return and caller operations. `BufferAssignmentReturnOpConverter` and
+/// `BufferAssignmentCallOpConverter` are two helper function that match the
+/// return and caller operation with the new function signature. Furthermore,
+/// `BufferAssignmentTypeConverter` is a helper `TypeConverter` for converting
+/// tensor typed values to memref typed ones.
+template <bool allowMemrefFunctionResults>
 class BufferAssignmentFuncOpConverter
     : public BufferAssignmentOpConversionPattern<FuncOp> {
 public:
@@ -196,16 +112,58 @@ class BufferAssignmentFuncOpConverter
       FuncOp>::BufferAssignmentOpConversionPattern;
 
   /// Performs the actual signature rewriting step.
-  LogicalResult matchAndRewrite(mlir::FuncOp, ArrayRef<Value>,
-                                ConversionPatternRewriter &) const;
+  LogicalResult
+  matchAndRewrite(mlir::FuncOp funcOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (!converter)
+      return funcOp.emitError("The type converter has not been defined for "
+                              "BufferAssignmentFuncOpConverter");
+    auto funcType = funcOp.getType();
+
+    // Convert function arguments using the provided TypeConverter.
+    TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
+    for (auto argType : llvm::enumerate(funcType.getInputs()))
+      conversion.addInputs(argType.index(),
+                           converter->convertType(argType.value()));
+
+    // If allowMemrefFunctionResults is false and a function result type is not
+    // a memref but it would be a memref after type conversion, a new argument
+    // should be appended to the function arguments list for this result.
+    // Otherwise, it remains unchanged as a function result.
+    SmallVector<Type, 2> newResultTypes;
+    newResultTypes.reserve(funcOp.getNumResults());
+    for (Type resType : funcType.getResults()) {
+      Type convertedType = converter->convertType(resType);
+      if (!allowMemrefFunctionResults &&
+          BufferAssignmentTypeConverter::isConvertedMemref(convertedType,
+                                                           resType))
+        conversion.addInputs(convertedType);
+      else
+        newResultTypes.push_back(convertedType);
+    }
+    if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter,
+                                           &conversion)))
+      return failure();
+
+    // Update the signature of the function.
+    rewriter.updateRootInPlace(funcOp, [&] {
+      funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
+                                              newResultTypes));
+    });
+    return success();
+  }
 };
 
 /// Rewrites the `ReturnOp` to conform with the changed function signature.
-/// Operands that correspond to return values and their types have been set to
-/// AppendToArgumentsList are dropped. In their place, a corresponding copy
-/// operation from the operand to the target function argument is inserted.
+/// if allowMemrefFunctionResults is false, operands that correspond to return
+/// values and have been rewritten from illegal typed results to memref
+/// arguments are dropped. In their place, a corresponding copy operation from
+/// the operand to the output function argument is inserted. Otherwise, the
+/// memref typed operands are returned.
+/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter,
+/// allowMemrefFunctionResults must be set/unset for both.
 template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
-          typename CopyOpTy>
+          typename CopyOpTy, bool allowMemrefFunctionResults>
 class BufferAssignmentReturnOpConverter
     : public BufferAssignmentOpConversionPattern<ReturnOpSourceTy> {
 public:
@@ -216,48 +174,44 @@ class BufferAssignmentReturnOpConverter
   LogicalResult
   matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    Location loc = returnOp.getLoc();
-
-    // Split the operands depending on whether they need a copy operation or
-    // they remain as operands of the return operation. If an operand is
-    // decomposable and a decompose callback function has been provided by the
-    // user, it will be unpacked.
-    SmallVector<Value, 2> newOperands, needCopyOperands;
-    OpBuilder builder(returnOp);
-    for (auto operand : llvm::enumerate(operands)) {
-      SmallVector<Value, 2> values;
-      this->converter->tryDecomposeValue(
-          builder, loc, operand.value().getType(), operand.value(), values);
-      Type type = returnOp.getOperand(operand.index()).getType();
-      SmallVector<Type, 2> originTypes;
-      this->converter->tryDecomposeType(type, originTypes);
-      for (auto value : llvm::enumerate(values)) {
-        Type origin = originTypes[value.index()];
-        Type converted = value.value().getType();
-        auto kind = this->converter->getResultConversionKind(origin, converted);
-        if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult)
-          newOperands.push_back(value.value());
-        else
-          // kind = BufferAssignmentTypeConverter::AppendToArgumentsList
-          needCopyOperands.push_back(value.value());
-      }
+    // If the memref typed results can be returned as function results, the new
+    // `ReturnOp` should only return the type converted operands.
+    if (allowMemrefFunctionResults) {
+      rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, operands);
+      return success();
     }
 
-    // Insert Copy operations instead for the operands that have been removed
-    // from operand list and appended to the function arguments list.
+    // Split the operands by their kinds whether they are converted memref or
+    // not.
+    SmallVector<Value, 2> needCopyOperands, newOperands;
+    unsigned operandsSize = operands.size();
+    needCopyOperands.reserve(operandsSize);
+    newOperands.reserve(operandsSize);
+    for (auto operand : llvm::enumerate(operands))
+      if (BufferAssignmentTypeConverter::isConvertedMemref(
+              operand.value().getType(),
+              returnOp.getOperand(operand.index()).getType()))
+        needCopyOperands.push_back(operand.value());
+      else
+        newOperands.push_back(operand.value());
+
     Block &entryBlock = returnOp.getParentRegion()->front();
     unsigned numFuncArgs = entryBlock.getNumArguments();
-    if (needCopyOperands.size() > numFuncArgs)
-      return returnOp.emitError(
-          "The number of operands that need Copy operations is more "
-          "than the number of target function arguments.");
+
+    // Find the index of the first destination buffer.
+    assert(needCopyOperands.size() <= numFuncArgs &&
+           "The number of operands of return operation is more than the "
+           "number of function arguments.");
     unsigned destArgNum = numFuncArgs - needCopyOperands.size();
     rewriter.setInsertionPoint(returnOp);
     for (Value operand : needCopyOperands) {
-      rewriter.create<CopyOpTy>(loc, operand,
+      // Insert a `CopyOp` for each converted memref-type operand.
+      rewriter.create<CopyOpTy>(returnOp.getLoc(), operand,
                                 entryBlock.getArgument(destArgNum));
       ++destArgNum;
     }
+
+    // Insert the new target Return operation.
     rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, newOperands);
     return success();
   }
@@ -265,32 +219,94 @@ class BufferAssignmentReturnOpConverter
 
 /// Rewrites the `CallOp` to match its operands and results with the signature
 /// of the callee after rewriting the callee with
-/// BufferAssignmentFuncOpConverter.
+/// BufferAssignmentFuncOpConverter. If allowMemrefFunctionResults is false, a
+/// buffer is allocated as an output buffer only for each memref typed result
+/// that has been rewritten. The new allocated buffer is passed through the
+/// operands list of the new `CallOp`.
+/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter,
+/// allowMemrefFunctionResults must be set/unset for both.
+template <bool allowMemrefFunctionResults>
 class BufferAssignmentCallOpConverter
     : public BufferAssignmentOpConversionPattern<CallOp> {
 public:
   using BufferAssignmentOpConversionPattern<
       CallOp>::BufferAssignmentOpConversionPattern;
 
-  /// Performs the actual rewriting step.
-  LogicalResult matchAndRewrite(CallOp, ArrayRef<Value>,
-                                ConversionPatternRewriter &) const;
+  LogicalResult
+  matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (!converter)
+      return callOp.emitError("The type converter has not been defined for "
+                              "BufferAssignmentCallOpConverter");
+    Location loc = callOp.getLoc();
+
+    // If the memref typed results can be returned as function results, there is
+    // no need to create output buffers. It is only required to convert the type
+    // of operands and results in place for creating the new `CallOp`.
+    if (allowMemrefFunctionResults) {
+      SmallVector<Type, 2> resultTypes;
+      resultTypes.reserve(callOp.getNumResults());
+      for (Type type : callOp.getResultTypes())
+        resultTypes.push_back(converter->convertType(type));
+      rewriter.replaceOpWithNewOp<CallOp>(callOp, callOp.getCallee(),
+                                          resultTypes, operands);
+      return success();
+    }
+
+    SmallVector<Value, 2> newOperands, replacingValues;
+    SmallVector<Type, 2> newResultTypes;
+    unsigned numResults = callOp.getNumResults();
+    newOperands.reserve(numResults + operands.size());
+    newOperands.append(operands.begin(), operands.end());
+    newResultTypes.reserve(numResults);
+    replacingValues.reserve(numResults);
+
+    // For each memref result of `CallOp` which has not been a memref before
+    // the type conversion, a new buffer is allocated and passed to the operands
+    // list of the new `CallOp`. Otherwise, it remains as a caller result.
+    for (Value result : callOp.getResults()) {
+      Type currType = result.getType();
+      Type newType = converter->convertType(result.getType());
+      if (BufferAssignmentTypeConverter::isConvertedMemref(newType, currType)) {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.restoreInsertionPoint(bufferAssignment->computeAllocPosition(
+            result.dyn_cast<OpResult>()));
+        Value alloc =
+            rewriter.create<AllocOp>(loc, newType.dyn_cast<MemRefType>());
+        newOperands.push_back(alloc);
+        replacingValues.push_back(alloc);
+      } else {
+        newResultTypes.push_back(currType);
+
+        // No replacing is required.
+        replacingValues.push_back(nullptr);
+      }
+    }
+
+    // Creating the new `CallOp`.
+    rewriter.create<CallOp>(loc, callOp.getCallee(), newResultTypes,
+                            newOperands);
+
+    // Replacing the results of the old `CallOp`.
+    rewriter.replaceOp(callOp, replacingValues);
+    return success();
+  }
 };
+} // end namespace detail
 
 /// Populates `patterns` with the conversion patterns of buffer
 /// assignment.
 template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
-          typename CopyOpTy>
+          typename CopyOpTy, bool allowMemrefFunctionResults>
 static void populateWithBufferAssignmentOpConversionPatterns(
     MLIRContext *context, BufferAssignmentPlacer *placer,
-    BufferAssignmentTypeConverter *converter,
-    OwningRewritePatternList *patterns) {
+    TypeConverter *converter, OwningRewritePatternList *patterns) {
   // clang-format off
   patterns->insert<
-    BufferAssignmentCallOpConverter,
-    BufferAssignmentFuncOpConverter,
-    BufferAssignmentReturnOpConverter
-      <ReturnOpSourceTy, ReturnOpTargetTy, CopyOpTy>
+    detail::BufferAssignmentCallOpConverter<allowMemrefFunctionResults>,
+    detail::BufferAssignmentFuncOpConverter<allowMemrefFunctionResults>,
+    detail::BufferAssignmentReturnOpConverter
+      <ReturnOpSourceTy, ReturnOpTargetTy, CopyOpTy, allowMemrefFunctionResults>
   >(context, placer, converter);
   // clang-format on
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
index 89a01f9ca629..04c1fbd5d565 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -100,11 +100,11 @@ class GenericOpConverter
 /// tensors to buffers.
 static void populateConvertLinalgOnTensorsToBuffersPattern(
     MLIRContext *context, BufferAssignmentPlacer *placer,
-    BufferAssignmentTypeConverter *converter,
-    OwningRewritePatternList *patterns) {
+    TypeConverter *converter, OwningRewritePatternList *patterns) {
   populateWithBufferAssignmentOpConversionPatterns<
-      mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer,
-                                                      converter, patterns);
+      mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp,
+      /*allowMemrefFunctionResults=*/false>(context, placer, converter,
+                                            patterns);
   patterns->insert<GenericOpConverter>(context, placer, converter);
 }
 
@@ -141,9 +141,6 @@ struct ConvertLinalgOnTensorsToBuffers
              converter.isLegal(&funcOp.getBody());
     });
 
-    converter.setResultConversionKind<RankedTensorType, MemRefType>(
-        BufferAssignmentTypeConverter::AppendToArgumentsList);
-
     // Walk over all the functions to apply buffer assignment.
     getOperation().walk([&](FuncOp function) -> WalkResult {
       OwningRewritePatternList patterns;

diff  --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
index 1ab3e7e2e48d..201570a244ff 100644
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -713,223 +713,9 @@ BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() {
   });
 }
 
-/// This method tries to decompose a value of a certain type using provided
-/// decompose callback functions. If it is unable to do so, the original value
-/// is returned.
-void BufferAssignmentTypeConverter::tryDecomposeValue(
-    OpBuilder &builder, Location loc, Type type, Value value,
-    SmallVectorImpl<Value> &results) {
-  for (auto conversion : decomposeValueConversions)
-    if (conversion(builder, loc, type, value, results) != llvm::None)
-      return;
-  results.push_back(value);
-}
-
-/// This method tries to decompose a type using provided decompose callback
-/// functions. If it is unable to do so, the original type is returned.
-void BufferAssignmentTypeConverter::tryDecomposeType(
-    Type type, SmallVectorImpl<Type> &types) {
-  for (auto conversion : decomposeTypeConversions)
-    if (conversion(type, types) != llvm::None)
-      return;
-  types.push_back(type);
-}
-
-/// This method returns ResultConversionKind for the input type.
-BufferAssignmentTypeConverter::ResultConversionKind
-BufferAssignmentTypeConverter::getResultConversionKind(Type origin,
-                                                       Type converted) {
-  for (auto conversion : resultTypeConversions) {
-    auto res = conversion(origin, converted);
-    if (res != llvm::None)
-      return res.getValue();
-  }
-  return KeepAsFunctionResult;
-}
-
-//===----------------------------------------------------------------------===//
-// BufferAssignmentFuncOpConverter
-//===----------------------------------------------------------------------===//
-
-/// Performs the actual function signature rewriting step.
-LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite(
-    mlir::FuncOp funcOp, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-  auto funcType = funcOp.getType();
-
-  // Convert function arguments using the provided TypeConverter.
-  TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
-  for (auto argType : llvm::enumerate(funcType.getInputs())) {
-    SmallVector<Type, 2> decomposedTypes, convertedTypes;
-    converter->tryDecomposeType(argType.value(), decomposedTypes);
-    converter->convertTypes(decomposedTypes, convertedTypes);
-    conversion.addInputs(argType.index(), convertedTypes);
-  }
-
-  // Convert the result types of the function.
-  SmallVector<Type, 2> newResultTypes;
-  newResultTypes.reserve(funcOp.getNumResults());
-  for (Type resultType : funcType.getResults()) {
-    SmallVector<Type, 2> originTypes;
-    converter->tryDecomposeType(resultType, originTypes);
-    for (auto origin : originTypes) {
-      Type converted = converter->convertType(origin);
-      auto kind = converter->getResultConversionKind(origin, converted);
-      if (kind == BufferAssignmentTypeConverter::AppendToArgumentsList)
-        conversion.addInputs(converted);
-      else
-        // kind = BufferAssignmentTypeConverter::KeepAsFunctionResult
-        newResultTypes.push_back(converted);
-    }
-  }
-
-  if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter,
-                                         &conversion)))
-    return failure();
-
-  // Update the signature of the function.
-  rewriter.updateRootInPlace(funcOp, [&] {
-    funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
-                                            newResultTypes));
-  });
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// BufferAssignmentCallOpConverter
-//===----------------------------------------------------------------------===//
-
-/// Performs the actual rewriting step.
-LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
-    CallOp callOp, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-
-  // This class represents a mapping from a result to a list of values and some
-  // results that have not yet constructed. Instead, the indices of these
-  // results in the operation that will be constructed are known. They will be
-  // replaced with the actual values when they are available. The order of
-  // adding to this mapping is important.
-  class ResultMapping {
-  public:
-    ResultMapping() { order = 0; };
-
-    /// Add an available value to the mapping.
-    void addMapping(Value value) {
-      toValuesMapping.push_back({order++, value});
-    }
-
-    /// Add the index of unavailble result value to the mapping.
-    void addMapping(unsigned index) {
-      toIndicesMapping.push_back({order++, index});
-    }
-
-    /// This method returns the mapping values list. The unknown result values
-    /// that only their indicies are available are replaced with their values.
-    void getMappingValues(ValueRange valuesToReplaceIndices,
-                          SmallVectorImpl<Value> &values) {
-      // Append available values to the list.
-      SmallVector<std::pair<unsigned, Value>, 2> res(toValuesMapping.begin(),
-                                                     toValuesMapping.end());
-      // Replace the indices with the actual values.
-      llvm::for_each(
-          toIndicesMapping, [&](const std::pair<unsigned, unsigned> &entry) {
-            assert(entry.second < valuesToReplaceIndices.size() &&
-                   "The value index is out of range.");
-            res.push_back({entry.first, valuesToReplaceIndices[entry.second]});
-          });
-      // Sort the values based on their adding orders.
-      llvm::sort(res, [](const std::pair<unsigned, Value> &v1,
-                         const std::pair<unsigned, Value> &v2) {
-        return v1.first < v2.first;
-      });
-      // Fill the values.
-      llvm::for_each(res, [&](const std::pair<unsigned, Value> &entry) {
-        values.push_back(entry.second);
-      });
-    }
-
-  private:
-    /// Keeping the inserting order of mapping values.
-    int order;
-
-    /// Containing the mapping values with their inserting orders.
-    SmallVector<std::pair<unsigned, Value>, 2> toValuesMapping;
-
-    /// Containing the indices of result values with their inserting orders.
-    SmallVector<std::pair<unsigned, unsigned>, 2> toIndicesMapping;
-  };
-
-  Location loc = callOp.getLoc();
-  OpBuilder builder(callOp);
-  SmallVector<Value, 2> newOperands;
-
-  // Create the operands list of the new `CallOp`. It unpacks the decomposable
-  // values if a decompose callback function has been provided by the user.
-  for (auto operand : operands) {
-    SmallVector<Value, 2> values;
-    this->converter->tryDecomposeValue(builder, loc, operand.getType(), operand,
-                                       values);
-    newOperands.append(values.begin(), values.end());
-  }
-
-  // Create the new result types for the new `CallOp` and a mapping from the old
-  // result to new value(s).
-  SmallVector<Type, 2> newResultTypes;
-  SmallVector<ResultMapping, 4> mappings;
-  mappings.resize(callOp.getNumResults());
-  for (auto result : llvm::enumerate(callOp.getResults())) {
-    SmallVector<Type, 2> originTypes;
-    converter->tryDecomposeType(result.value().getType(), originTypes);
-    auto &resultMapping = mappings[result.index()];
-    for (Type origin : originTypes) {
-      Type converted = converter->convertType(origin);
-      auto kind = converter->getResultConversionKind(origin, converted);
-      if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) {
-        newResultTypes.push_back(converted);
-        // The result value is not yet available. Its index is kept and it is
-        // replaced with the actual value of the new `CallOp` later.
-        resultMapping.addMapping(newResultTypes.size() - 1);
-      } else {
-        // kind = BufferAssignmentTypeConverter::AppendToArgumentsList
-        OpBuilder::InsertionGuard guard(rewriter);
-        rewriter.restoreInsertionPoint(
-            bufferAssignment->computeAllocPosition(result.value()));
-        MemRefType memref = converted.dyn_cast<MemRefType>();
-        if (!memref)
-          return callOp.emitError("Cannot allocate for a non-Memref type");
-        Value alloc = rewriter.create<AllocOp>(loc, memref);
-        newOperands.push_back(alloc);
-        resultMapping.addMapping(alloc);
-      }
-    }
-  }
-
-  CallOp newCallOp = rewriter.create<CallOp>(loc, callOp.getCallee(),
-                                             newResultTypes, newOperands);
-
-  // Build a replacing value for each result to replace its uses. If a result
-  // has multiple mapping values, it needs to be packed to a single value.
-  OpBuilder nextBuilder(callOp.getOperation()->getNextNode());
-  SmallVector<Value, 2> replacedValues;
-  replacedValues.reserve(callOp.getNumResults());
-  for (unsigned i = 0, e = callOp.getNumResults(); i < e; ++i) {
-    SmallVector<Value, 2> valuesToPack;
-    mappings[i].getMappingValues(newCallOp.getResults(), valuesToPack);
-    if (valuesToPack.empty()) {
-      // No replacement is required.
-      replacedValues.push_back(nullptr);
-    } else if (valuesToPack.size() == 1) {
-      replacedValues.push_back(valuesToPack.front());
-    } else {
-      // Values need to be packed using callback function. The same callback
-      // that is used for materializeArgumentConversion is used for packing.
-      Value packed = converter->materializeArgumentConversion(
-          nextBuilder, loc, callOp.getType(i), valuesToPack);
-      replacedValues.push_back(packed);
-    }
-  }
-  rewriter.replaceOp(callOp, replacedValues);
-  return success();
+/// Checks if `type` has been converted from non-memref type to memref.
+bool BufferAssignmentTypeConverter::isConvertedMemref(Type type, Type before) {
+  return type.isa<BaseMemRefType>() && !before.isa<BaseMemRefType>();
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
index e1dacdf0184e..084ac38af6e3 100644
--- a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
+++ b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
@@ -111,73 +111,7 @@ func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
 // CHECK: %[[Y:.*]]:2 = call @callee(%[[X]]#0)
 // CHECK: return %[[Y]]#0
 
-// -----
-
-// Test case: Testing BufferAssginmnetCallOpConverter to see if it matches with the
-// signature of the new signature of the callee function when there are tuple typed
-// args and results. BufferAssginmentTypeConverter is set to flatten tuple typed
-// arguments. The tuple typed values should be decomposed and composed using
-// get_tuple_element and make_tuple operations of test dialect. Tensor types are
-// converted to Memref. Memref typed function results remain as function results.
 
-// CHECK-LABEL: func @callee
-func @callee(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>){
-  return %arg0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
-}
-// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>)
-// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>)
-// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]]  = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]
 
-// CHECK-LABEL: func @caller
-func @caller(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> tuple<tensor<2xf32>,i1, tensor<5xf32>>{
-  %x0 = call @callee(%arg0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
-  %y0 = call @callee(%x0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
-  return %y0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
-}
-// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>)
-// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>)
-// CHECK-NEXT: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]]  = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]])
-// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>)
-// CHECK-NEXT: %[[RESULT_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2)
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]]  = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]])
-// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>)
-// CHECK-NEXT: %[[RETURN_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2)
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]]  = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]
 
-// -----
 
-// Test case: Testing BufferAssginmnetFuncOpConverter and
-// BufferAssginmentReturnOpConverter to see if the return operation matches with
-// the new function signature when there are tuple typed args and results.
-// BufferAssginmentTypeConverter is set to flatten tuple typed arguments. The tuple
-// typed values should be decomposed and composed using get_tuple_element and
-// make_tuple operations of test dialect. Tensor types are converted to Memref.
-// Memref typed function results remain as function results.
-
-// CHECK-LABEL: func @decompose_tuple_typed_function_args_and_results
-func @decompose_tuple_typed_function_args_and_results(%arg0: tuple<i1,f32>, %arg1: tensor<10xf32>, %arg2: tuple<i1, tensor<5xf32>>) -> (tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>){
-  return %arg2, %arg1, %arg0 : tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>
-}
-// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32>
-// CHECK-SAME: (i1, memref<5xf32>, memref<10xf32>, i1, f32)
-// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]])
-// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]])
-// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]]  = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[SECOND_TUPLE_SECOND_ELEM]], %[[ARG2]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]]

diff  --git a/mlir/test/Transforms/buffer-placement-preparation.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir
index b1cfdfd690cf..064b0fd7e85a 100644
--- a/mlir/test/Transforms/buffer-placement-preparation.mlir
+++ b/mlir/test/Transforms/buffer-placement-preparation.mlir
@@ -285,93 +285,8 @@ func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
 // CHECK: linalg.copy(%[[Y0]], %[[CALLER_RESULT]])
 // CHECK: return
 
-// -----
-
 // CHECK-LABEL: func @func_with_unranked_arg
 func @func_with_unranked_arg(%arg0: tensor<*xf32>) {
   return
 }
 // CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)
-
-// -----
-
-// Test case: Testing BufferAssginmnetCallOpConverter to see if it matches with the
-// signature of the new signature of the callee function when there are tuple typed
-// args and results. BufferAssginmentTypeConverter is set to flatten tuple typed
-// arguments. The tuple typed values should be decomposed and composed using
-// get_tuple_element and make_tuple operations of test dialect. Tensor types are
-// converted to Memref. Memref typed function results are appended to the function
-// arguments list.
-
-// CHECK-LABEL: func @callee
-func @callee(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>){
-  return %arg0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
-}
-// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<2xf32>, %[[RESULT1:.*]]: memref<5xf32>)
-// CHECK-SAME: i1
-// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]]  = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: linalg.copy(%[[FIRST_ELEM]], %[[RESULT0]])
-// CHECK-NEXT: linalg.copy(%[[THIRD_ELEM]], %[[RESULT1]])
-// CHECK-NEXT: return %[[SECOND_ELEM]]
-
-
-// CHECK-LABEL: func @caller
-func @caller(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> tuple<tensor<2xf32>,i1, tensor<5xf32>>{
-  %x0 = call @callee(%arg0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
-  %y0 = call @callee(%x0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
-  return %y0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
-}
-// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<2xf32>, %[[RESULT1:.*]]: memref<5xf32>)
-// CHECK-SAME: i1
-// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]]  = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc()
-// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc()
-// CHECK-NEXT: %[[CALLEE_RESULT:.*]] = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]])
-// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>, memref<2xf32>, memref<5xf32>) -> i1
-// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[FIRST_ALLOC]], %[[CALLEE_RESULT]], %[[SECOND_ALLOC]])
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]]  = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc()
-// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc()
-// CHECK-NEXT: %[[CALLEE_RESULT:.*]] = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]])
-// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>, memref<2xf32>, memref<5xf32>) -> i1
-// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[FIRST_ALLOC]], %[[CALLEE_RESULT]], %[[SECOND_ALLOC]])
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]]  = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: linalg.copy(%[[FIRST_ELEM]], %[[RESULT0]])
-// CHECK-NEXT: linalg.copy(%[[THIRD_ELEM]], %[[RESULT1]])
-// CHECK-NEXT: return %[[SECOND_ELEM]]
-
-// -----
-
-// Test case: Testing BufferAssginmnetFuncOpConverter and
-// BufferAssginmentReturnOpConverter to see if the return operation matches with
-// the new function signature when there are tuple typed args and results.
-// BufferAssginmentTypeConverter is set to flatten tuple typed arguments. The tuple
-// typed values should be decomposed and composed using get_tuple_element and
-// make_tuple operations of test dialect. Tensor types are converted to Memref.
-// Memref typed function results are appended to the function arguments list.
-
-// CHECK-LABEL: func @decompose_tuple_typed_function_args_and_results
-func @decompose_tuple_typed_function_args_and_results(%arg0: tuple<i1,f32>, %arg1: tensor<10xf32>, %arg2: tuple<i1, tensor<5xf32>>) -> (tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>){
-  return %arg2, %arg1, %arg0 : tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>
-}
-// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<5xf32>, %[[RESULT1:.*]]: memref<10xf32>
-// CHECK-SAME: (i1, i1, f32)
-// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]])
-// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]])
-// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]]  = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: linalg.copy(%[[SECOND_TUPLE_SECOND_ELEM]], %[[RESULT0]])
-// CHECK-NEXT: linalg.copy(%[[ARG2]], %[[RESULT1]])
-// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]]

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index f03c953396a4..bc26a8659831 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1669,7 +1669,7 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
   let results = (outs AnyType:$result);
 
   let extraClassDeclaration = [{
-    static LogicalResult inferReturnTypes(MLIRContext *,
+    static LogicalResult inferReturnTypes(MLIRContext *, 
           Optional<Location> location, ValueRange operands,
           DictionaryAttr attributes, RegionRange regions,
           SmallVectorImpl<Type> &inferredReturnTypes) {
@@ -1679,31 +1679,4 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
    }];
 }
 
-//===----------------------------------------------------------------------===//
-// Test BufferPlacement
-//===----------------------------------------------------------------------===//
-
-def GetTupleElementOp: TEST_Op<"get_tuple_element"> {
-  let description = [{
-    Test op that returns a specified element of the tuple.
-  }];
-
-  let arguments = (ins
-    TupleOf<[AnyType]>,
-    I32Attr:$index
-  );
-  let results = (outs AnyType);
-}
-
-def MakeTupleOp: TEST_Op<"make_tuple"> {
-  let description = [{
-    Test op that creates a tuple value from a list of values.
-  }];
-
-  let arguments = (ins
-    Variadic<AnyType>:$inputs
-  );
-  let results = (outs TupleOf<[AnyType]>);
-}
-
 #endif // TEST_OPS

diff  --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
index 14b72b9fc92a..6cc0924191cb 100644
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -11,8 +11,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "TestDialect.h"
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Operation.h"
@@ -111,16 +109,14 @@ struct TestBufferPlacementPreparationPass
 
   void populateTensorLinalgToBufferLinalgConversionPattern(
       MLIRContext *context, BufferAssignmentPlacer *placer,
-      BufferAssignmentTypeConverter *converter,
-      OwningRewritePatternList *patterns) {
+      TypeConverter *converter, OwningRewritePatternList *patterns) {
     populateWithBufferAssignmentOpConversionPatterns<
-        mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer,
-                                                        converter, patterns);
+        mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp,
+        allowMemrefFunctionResults>(context, placer, converter, patterns);
     patterns->insert<GenericOpConverter>(context, placer, converter);
   }
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<TestDialect>();
     registry.insert<linalg::LinalgDialect>();
   }
 
@@ -131,8 +127,6 @@ struct TestBufferPlacementPreparationPass
 
     // Mark all Standard operations legal.
     target.addLegalDialect<StandardOpsDialect>();
-    target.addLegalOp<MakeTupleOp>();
-    target.addLegalOp<GetTupleElementOp>();
 
     // Mark all Linalg operations illegal as long as they work on tensors.
     auto isLegalOperation = [&](Operation *op) {
@@ -155,42 +149,6 @@ struct TestBufferPlacementPreparationPass
              converter.isLegal(&funcOp.getBody());
     });
 
-    auto kind = allowMemrefFunctionResults
-                    ? BufferAssignmentTypeConverter::KeepAsFunctionResult
-                    : BufferAssignmentTypeConverter::AppendToArgumentsList;
-    converter.setResultConversionKind<RankedTensorType, MemRefType>(kind);
-    converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>(
-        kind);
-
-    converter.addDecomposeTypeConversion(
-        [](TupleType tupleType, SmallVectorImpl<Type> &types) {
-          tupleType.getFlattenedTypes(types);
-          return success();
-        });
-
-    converter.addArgumentMaterialization(
-        [](OpBuilder &builder, TupleType resultType, ValueRange inputs,
-           Location loc) -> Optional<Value> {
-          if (inputs.size() == 1)
-            return llvm::None;
-          TypeRange TypeRange = inputs.getTypes();
-          SmallVector<Type, 2> types(TypeRange.begin(), TypeRange.end());
-          TupleType tuple = TupleType::get(types, builder.getContext());
-          mlir::Value value = builder.create<MakeTupleOp>(loc, tuple, inputs);
-          return value;
-        });
-
-    converter.addDecomposeValueConversion([](OpBuilder &builder, Location loc,
-                                             TupleType resultType, Value value,
-                                             SmallVectorImpl<Value> &values) {
-      for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
-        Value res = builder.create<GetTupleElementOp>(
-            loc, resultType.getType(i), value, builder.getI32IntegerAttr(i));
-        values.push_back(res);
-      }
-      return success();
-    });
-
     // Walk over all the functions to apply buffer assignment.
     this->getOperation().walk([&](FuncOp function) -> WalkResult {
       OwningRewritePatternList patterns;


        


More information about the Mlir-commits mailing list