[Mlir-commits] [mlir] 7c62c63 - [mlir] Add DecomposeCallGraphTypes pass.

Sean Silva llvmlistbot at llvm.org
Mon Nov 16 12:30:31 PST 2020


Author: Sean Silva
Date: 2020-11-16T12:25:35-08:00
New Revision: 7c62c6313baebb4866dd51a095c66c7808af868b

URL: https://github.com/llvm/llvm-project/commit/7c62c6313baebb4866dd51a095c66c7808af868b
DIFF: https://github.com/llvm/llvm-project/commit/7c62c6313baebb4866dd51a095c66c7808af868b.diff

LOG: [mlir] Add DecomposeCallGraphTypes pass.

This replaces the old type decomposition logic that was previously mixed
into bufferization, and makes it available as a standalone set of conversion
patterns.

This also deletes TestFinalizingBufferize, because once the type
decomposition is removed, it does nothing that func-bufferize does not already
provide.

Differential Revision: https://reviews.llvm.org/D90899

Added: 
    mlir/include/mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h
    mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp
    mlir/test/Transforms/decompose-call-graph-types.mlir
    mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp

Modified: 
    mlir/include/mlir/Transforms/Bufferize.h
    mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
    mlir/lib/Transforms/Bufferize.cpp
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    mlir/test/Transforms/finalizing-bufferize.mlir
    mlir/test/lib/Transforms/TestFinalizingBufferize.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h
new file mode 100644
index 000000000000..49895acd9d24
--- /dev/null
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h
@@ -0,0 +1,109 @@
+//===- DecomposeCallGraphTypes.h - CG type decompositions -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Conversion patterns for decomposing types along call graph edges. That is,
+// decomposing types for calls, returns, and function args.
+//
+// TODO: Make this handle dialect-defined functions, calls, and returns.
+// Currently, the generic interfaces aren't sophisticated enough for the
+// types of mutations that we are doing here.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_STANDARDOPS_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
+#define MLIR_DIALECT_STANDARDOPS_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+
+/// This class provides a hook that expands one Value into multiple Values,
+/// with a TypeConverter-inspired callback registration mechanism.
+///
+/// For folks that are familiar with the dialect conversion framework /
+/// TypeConverter, this is effectively the inverse of a source/argument
+/// materialization. A target materialization is not what we want here because
+/// it always produces a single Value, but in this case the whole point is to
+/// decompose a Value into multiple Values.
+///
+/// The reason we need this inverse is easily understood by looking at what we
+/// need to do for decomposing types for a return op. When converting a return
+/// op, the dialect conversion framework will give the list of converted
+/// operands, and will ensure that each converted operand, even if it expanded
+/// into multiple types, is materialized as a single result. We then need to
+/// undo that materialization to a single result, which we do with the
+/// decomposeValue hooks registered on this object.
+///
+/// TODO: Eventually, the type conversion infra should have this hook built-in.
+/// See
+/// https://llvm.discourse.group/t/extending-type-conversion-infrastructure/779/2
+class ValueDecomposer {
+public:
+  /// Decomposes `value` (of type `type`) into zero or more values, appending
+  /// them to `results`. The registered callbacks are tried in registration
+  /// order; the first one that handles `type` (i.e. returns something other
+  /// than `llvm::None`) wins. If no callback handles `type`, `value` itself is
+  /// appended to `results` unchanged.
+  ///
+  /// Note: a callback that returns `failure()` still counts as having handled
+  /// the value; the failure is not reported to the caller.
+  void decomposeValue(OpBuilder &builder, Location loc, Type type, Value value,
+                      SmallVectorImpl<Value> &results);
+
+  /// Registers a callback that decomposes values of a particular Type
+  /// subclass `T` into 0, 1, or multiple values. The callback must be
+  /// invocable as:
+  ///
+  ///   Optional<LogicalResult>(OpBuilder &builder, Location loc, T type,
+  ///                           Value value, SmallVectorImpl<Value> &newValues)
+  ///
+  /// `T` is deduced from the callback's third parameter. The callback is only
+  /// invoked for values whose type is a `T`; it may return `llvm::None` to
+  /// decline, in which case the next registered callback is tried.
+  template <typename FnT,
+            typename T = typename llvm::function_traits<FnT>::template arg_t<2>>
+  void addDecomposeValueConversion(FnT &&callback) {
+    decomposeValueConversions.emplace_back(
+        wrapDecomposeValueConversionCallback<T>(std::forward<FnT>(callback)));
+  }
+
+private:
+  using DecomposeValueConversionCallFn = std::function<Optional<LogicalResult>(
+      OpBuilder &, Location, Type, Value, SmallVectorImpl<Value> &)>;
+
+  /// Generate a wrapper for the given decompose value conversion callback. The
+  /// wrapper returns `llvm::None` for types that are not a `T`.
+  template <typename T, typename FnT>
+  DecomposeValueConversionCallFn
+  wrapDecomposeValueConversionCallback(FnT &&callback) {
+    return [callback = std::forward<FnT>(callback)](
+               OpBuilder &builder, Location loc, Type type, Value value,
+               SmallVectorImpl<Value> &newValues) -> Optional<LogicalResult> {
+      if (T derivedType = type.dyn_cast<T>())
+        return callback(builder, loc, derivedType, value, newValues);
+      return llvm::None;
+    };
+  }
+
+  /// Registered callbacks, in registration order.
+  SmallVector<DecomposeValueConversionCallFn, 2> decomposeValueConversions;
+};
+
+/// Populates `patterns` with the patterns needed to decompose call graph types
+/// (FuncOp signatures, CallOp, and ReturnOp) using the given `typeConverter`
+/// and `decomposer`.
+///
+/// The created patterns refer to `typeConverter` and `decomposer` without
+/// copying them, so both must outlive any use of `patterns`.
+void populateDecomposeCallGraphTypesPatterns(
+    MLIRContext *context, TypeConverter &typeConverter,
+    ValueDecomposer &decomposer, OwningRewritePatternList &patterns);
+
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H

diff  --git a/mlir/include/mlir/Transforms/Bufferize.h b/mlir/include/mlir/Transforms/Bufferize.h
index 3434be4214a7..1fde1d2ebce1 100644
--- a/mlir/include/mlir/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Transforms/Bufferize.h
@@ -15,13 +15,8 @@
 //
 // Bufferization conversion patterns should generally use the ordinary
 // conversion pattern classes (e.g. OpConversionPattern). A TypeConverter
-// (accessible with getTypeConverter()) available on such patterns is sufficient
-// for most cases (if needed at all).
-//
-// But some patterns require access to the extra functions on
-// BufferizeTypeConverter that don't exist on the base TypeConverter class. For
-// those cases, BufferizeConversionPattern and its related classes should be
-// used, which provide access to a BufferizeTypeConverter directly.
+// (accessible with getTypeConverter()) is available if needed for converting
+// types.
 //
 //===----------------------------------------------------------------------===//
 
@@ -39,79 +34,11 @@
 
 namespace mlir {
 
-/// A helper type converter class for using inside Buffer Assignment operation
-/// conversion patterns. The default constructor keeps all the types intact
-/// except for the ranked-tensor types which is converted to memref types.
+/// A helper type converter class that automatically populates the relevant
+/// materializations and type conversions for bufferization.
 class BufferizeTypeConverter : public TypeConverter {
 public:
   BufferizeTypeConverter();
-
-  /// This method tries to decompose a value of a certain type using provided
-  /// decompose callback functions. If it is unable to do so, the original value
-  /// is returned.
-  void tryDecomposeValue(OpBuilder &, Location, Type, Value,
-                         SmallVectorImpl<Value> &);
-
-  /// This method tries to decompose a type using provided decompose callback
-  /// functions. If it is unable to do so, the original type is returned.
-  void tryDecomposeType(Type, SmallVectorImpl<Type> &);
-
-  /// This method registers a callback function that will be called to decompose
-  /// a value of a certain type into several values.
-  template <typename FnT,
-            typename T = typename llvm::function_traits<FnT>::template arg_t<2>>
-  void addDecomposeValueConversion(FnT &&callback) {
-    decomposeValueConversions.emplace_back(
-        wrapDecomposeValueConversionCallback<T>(std::forward<FnT>(callback)));
-  }
-
-  /// This method registers a callback function that will be called to decompose
-  /// a type into several types.
-  template <typename FnT,
-            typename T = typename llvm::function_traits<FnT>::template arg_t<0>>
-  void addDecomposeTypeConversion(FnT &&callback) {
-    auto wrapper =
-        wrapDecomposeTypeConversionCallback<T>(std::forward<FnT>(callback));
-    decomposeTypeConversions.emplace_back(wrapper);
-    addConversion(std::forward<FnT>(callback));
-  }
-
-private:
-  using DecomposeValueConversionCallFn = std::function<Optional<LogicalResult>(
-      OpBuilder &, Location, Type, Value, SmallVectorImpl<Value> &)>;
-
-  using DecomposeTypeConversionCallFn =
-      std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
-
-  /// Generate a wrapper for the given decompose value conversion callback.
-  template <typename T, typename FnT>
-  DecomposeValueConversionCallFn
-  wrapDecomposeValueConversionCallback(FnT &&callback) {
-    return [callback = std::forward<FnT>(callback)](
-               OpBuilder &builder, Location loc, Type type, Value value,
-               SmallVectorImpl<Value> &newValues) -> Optional<LogicalResult> {
-      if (T derivedType = type.dyn_cast<T>())
-        return callback(builder, loc, derivedType, value, newValues);
-      return llvm::None;
-    };
-  }
-
-  /// Generate a wrapper for the given decompose type conversion callback.
-  template <typename T, typename FnT>
-  DecomposeTypeConversionCallFn
-  wrapDecomposeTypeConversionCallback(FnT &&callback) {
-    return [callback = std::forward<FnT>(callback)](
-               Type type,
-               SmallVectorImpl<Type> &results) -> Optional<LogicalResult> {
-      T derivedType = type.dyn_cast<T>();
-      if (!derivedType)
-        return llvm::None;
-      return callback(derivedType, results);
-    };
-  }
-
-  SmallVector<DecomposeValueConversionCallFn, 2> decomposeValueConversions;
-  SmallVector<DecomposeTypeConversionCallFn, 2> decomposeTypeConversions;
 };
 
 /// Marks ops used by bufferization for type conversion materializations as
@@ -132,104 +59,6 @@ void populateEliminateBufferizeMaterializationsPatterns(
     MLIRContext *context, BufferizeTypeConverter &typeConverter,
     OwningRewritePatternList &patterns);
 
-/// Helper conversion pattern that encapsulates a BufferizeTypeConverter
-/// instance.
-template <typename SourceOp>
-class BufferizeOpConversionPattern : public OpConversionPattern<SourceOp> {
-public:
-  explicit BufferizeOpConversionPattern(MLIRContext *context,
-                                        BufferizeTypeConverter &converter,
-                                        PatternBenefit benefit = 1)
-      : OpConversionPattern<SourceOp>(context, benefit), converter(converter) {}
-
-protected:
-  BufferizeTypeConverter &converter;
-};
-
-/// Helper conversion pattern that encapsulates a BufferizeTypeConverter
-/// instance and that operates on Operation* to be compatible with OpInterfaces.
-/// This allows avoiding to instantiate N patterns for ops that can be subsumed
-/// by a single op interface (e.g. Linalg named ops).
-class BufferizeConversionPattern : public ConversionPattern {
-public:
-  explicit BufferizeConversionPattern(MLIRContext *context,
-                                      BufferizeTypeConverter &converter,
-                                      PatternBenefit benefit = 1)
-      : ConversionPattern(benefit, converter, MatchAnyOpTypeTag()),
-        converter(converter) {}
-
-protected:
-  BufferizeTypeConverter &converter;
-};
-
-/// Converts the signature of the function using BufferizeTypeConverter.
-/// Each result type of the function is kept as a function result or appended to
-/// the function arguments list based on ResultConversionKind for the converted
-/// result type.
-class BufferizeFuncOpConverter : public BufferizeOpConversionPattern<FuncOp> {
-public:
-  using BufferizeOpConversionPattern<FuncOp>::BufferizeOpConversionPattern;
-
-  /// Performs the actual signature rewriting step.
-  LogicalResult matchAndRewrite(mlir::FuncOp, ArrayRef<Value>,
-                                ConversionPatternRewriter &) const override;
-};
-
-/// Rewrites the `ReturnOp` to conform with the changed function signature.
-/// Operands that correspond to return values and their types have been set to
-/// AppendToArgumentsList are dropped. In their place, a corresponding copy
-/// operation from the operand to the target function argument is inserted.
-template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
-          typename CopyOpTy>
-class BufferizeReturnOpConverter
-    : public BufferizeOpConversionPattern<ReturnOpSourceTy> {
-public:
-  using BufferizeOpConversionPattern<
-      ReturnOpSourceTy>::BufferizeOpConversionPattern;
-
-  /// Performs the actual return-op conversion step.
-  LogicalResult
-  matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    SmallVector<Value, 2> newOperands;
-    for (auto operand : operands)
-      this->converter.tryDecomposeValue(
-          rewriter, returnOp.getLoc(), operand.getType(), operand, newOperands);
-    rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, newOperands);
-    return success();
-  }
-};
-
-/// Rewrites the `CallOp` to match its operands and results with the signature
-/// of the callee after rewriting the callee with
-/// BufferizeFuncOpConverter.
-class BufferizeCallOpConverter : public BufferizeOpConversionPattern<CallOp> {
-public:
-  using BufferizeOpConversionPattern<CallOp>::BufferizeOpConversionPattern;
-
-  /// Performs the actual rewriting step.
-  LogicalResult matchAndRewrite(CallOp, ArrayRef<Value>,
-                                ConversionPatternRewriter &) const override;
-};
-
-/// Populates `patterns` with the conversion patterns of buffer
-/// assignment.
-template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
-          typename CopyOpTy>
-static void
-populateWithBufferizeOpConversionPatterns(MLIRContext *context,
-                                          BufferizeTypeConverter &converter,
-                                          OwningRewritePatternList &patterns) {
-  // clang-format off
-  patterns.insert<
-    BufferizeCallOpConverter,
-    BufferizeFuncOpConverter,
-    BufferizeReturnOpConverter
-      <ReturnOpSourceTy, ReturnOpTargetTy, CopyOpTy>
-  >(context, converter);
-  // clang-format on
-}
-
 /// A simple analysis that detects allocation operations.
 class BufferPlacementAllocs {
 public:

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
index ce5494cf855b..0465ff1b2cf0 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRStandardOpsTransforms
   Bufferize.cpp
+  DecomposeCallGraphTypes.cpp
   ExpandOps.cpp
   ExpandTanh.cpp
   FuncBufferize.cpp

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp
new file mode 100644
index 000000000000..fdd73b21237a
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp
@@ -0,0 +1,212 @@
+//===- DecomposeCallGraphTypes.cpp - CG type decomposition ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Function.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// ValueDecomposer
+//===----------------------------------------------------------------------===//
+
+/// Tries the registered callbacks in registration order. A callback handles
+/// the value as soon as it returns anything other than llvm::None (even
+/// failure()); if none does, `value` is appended to `results` unchanged.
+void ValueDecomposer::decomposeValue(OpBuilder &builder, Location loc,
+                                     Type type, Value value,
+                                     SmallVectorImpl<Value> &results) {
+  for (auto &conversion : decomposeValueConversions)
+    if (conversion(builder, loc, type, value, results))
+      return;
+  results.push_back(value);
+}
+
+//===----------------------------------------------------------------------===//
+// DecomposeCallGraphTypesOpConversionPattern
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Base OpConversionPattern class to make a ValueDecomposer available to
+/// inherited patterns.
+template <typename SourceOp>
+class DecomposeCallGraphTypesOpConversionPattern
+    : public OpConversionPattern<SourceOp> {
+public:
+  DecomposeCallGraphTypesOpConversionPattern(TypeConverter &typeConverter,
+                                             MLIRContext *context,
+                                             ValueDecomposer &decomposer,
+                                             PatternBenefit benefit = 1)
+      : OpConversionPattern<SourceOp>(typeConverter, context, benefit),
+        decomposer(decomposer) {}
+
+protected:
+  ValueDecomposer &decomposer;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// DecomposeCallGraphTypesForFuncArgs
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Expand function arguments according to the provided TypeConverter and
+/// ValueDecomposer.
+struct DecomposeCallGraphTypesForFuncArgs
+    : public DecomposeCallGraphTypesOpConversionPattern<FuncOp> {
+  using DecomposeCallGraphTypesOpConversionPattern::
+      DecomposeCallGraphTypesOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(FuncOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto functionType = op.getType();
+
+    // Convert function arguments using the provided TypeConverter. An
+    // argument that converts to zero types gets no new inputs.
+    TypeConverter::SignatureConversion conversion(functionType.getNumInputs());
+    for (auto argType : llvm::enumerate(functionType.getInputs())) {
+      SmallVector<Type, 2> decomposedTypes;
+      if (failed(getTypeConverter()->convertType(argType.value(),
+                                                 decomposedTypes)))
+        return failure();
+      if (!decomposedTypes.empty())
+        conversion.addInputs(argType.index(), decomposedTypes);
+    }
+
+    // Convert the result types before touching the IR so that a failed
+    // conversion bails out without leaving a partially rewritten function.
+    SmallVector<Type, 2> newResultTypes;
+    if (failed(getTypeConverter()->convertTypes(functionType.getResults(),
+                                                newResultTypes)))
+      return failure();
+
+    // If the SignatureConversion doesn't apply, bail out.
+    if (failed(rewriter.convertRegionTypes(&op.getBody(), *getTypeConverter(),
+                                           &conversion)))
+      return failure();
+
+    // Update the signature of the function.
+    rewriter.updateRootInPlace(op, [&] {
+      op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
+                                          newResultTypes));
+    });
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// DecomposeCallGraphTypesForReturnOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Expand return operands according to the provided TypeConverter and
+/// ValueDecomposer.
+struct DecomposeCallGraphTypesForReturnOp
+    : public DecomposeCallGraphTypesOpConversionPattern<ReturnOp> {
+  using DecomposeCallGraphTypesOpConversionPattern::
+      DecomposeCallGraphTypesOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // Expand each (already type-converted) operand into its decomposed values.
+    SmallVector<Value, 2> newOperands;
+    for (Value operand : operands)
+      decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
+                                operand, newOperands);
+    rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// DecomposeCallGraphTypesForCallOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Expand call op operands and results according to the provided TypeConverter
+/// and ValueDecomposer.
+struct DecomposeCallGraphTypesForCallOp
+    : public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
+  using DecomposeCallGraphTypesOpConversionPattern::
+      DecomposeCallGraphTypesOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CallOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // Create the new result types for the new `CallOp` and track the indices in
+    // the new call op's results that correspond to the old call op's results.
+    // This is done before any IR is created so that a failed type conversion
+    // bails out cleanly.
+    //
+    // expandedResultIndices[i] = "list of new result indices that old result i
+    // expanded to".
+    SmallVector<Type, 2> newResultTypes;
+    SmallVector<SmallVector<unsigned, 2>, 4> expandedResultIndices;
+    for (Type resultType : op.getResultTypes()) {
+      unsigned oldSize = newResultTypes.size();
+      if (failed(getTypeConverter()->convertType(resultType, newResultTypes)))
+        return failure();
+      auto &resultMapping = expandedResultIndices.emplace_back();
+      for (unsigned i = oldSize, e = newResultTypes.size(); i < e; i++)
+        resultMapping.push_back(i);
+    }
+
+    // Create the operands list of the new `CallOp`.
+    SmallVector<Value, 2> newOperands;
+    for (Value operand : operands)
+      decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
+                                operand, newOperands);
+
+    CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCallee(),
+                                               newResultTypes, newOperands);
+
+    // Build a replacement value for each result to replace its uses. If a
+    // result has multiple mapping values, it needs to be materialized as a
+    // single value.
+    SmallVector<Value, 2> replacedValues;
+    replacedValues.reserve(op.getNumResults());
+    for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
+      auto decomposedValues = llvm::to_vector<6>(
+          llvm::map_range(expandedResultIndices[i], [&](unsigned newIndex) {
+            return newCallOp.getResult(newIndex);
+          }));
+      if (decomposedValues.empty()) {
+        // No replacement is required.
+        replacedValues.push_back(nullptr);
+      } else if (decomposedValues.size() == 1) {
+        replacedValues.push_back(decomposedValues.front());
+      } else {
+        // Materialize a single Value to replace the original Value. Fail
+        // (rather than replace uses with a null Value) if the TypeConverter
+        // has no argument materialization that applies.
+        Value materialized = getTypeConverter()->materializeArgumentConversion(
+            rewriter, op.getLoc(), op.getType(i), decomposedValues);
+        if (!materialized)
+          return failure();
+        replacedValues.push_back(materialized);
+      }
+    }
+    rewriter.replaceOp(op, replacedValues);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateDecomposeCallGraphTypesPatterns(
+    MLIRContext *context, TypeConverter &typeConverter,
+    ValueDecomposer &decomposer, OwningRewritePatternList &patterns) {
+  patterns.insert<DecomposeCallGraphTypesForCallOp,
+                  DecomposeCallGraphTypesForFuncArgs,
+                  DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
+                                                      decomposer);
+}

diff  --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp
index 4ca446f06669..ba622335a396 100644
--- a/mlir/lib/Transforms/Bufferize.cpp
+++ b/mlir/lib/Transforms/Bufferize.cpp
@@ -41,28 +41,6 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
   });
 }
 
-/// This method tries to decompose a value of a certain type using provided
-/// decompose callback functions. If it is unable to do so, the original value
-/// is returned.
-void BufferizeTypeConverter::tryDecomposeValue(
-    OpBuilder &builder, Location loc, Type type, Value value,
-    SmallVectorImpl<Value> &results) {
-  for (auto &conversion : decomposeValueConversions)
-    if (conversion(builder, loc, type, value, results))
-      return;
-  results.push_back(value);
-}
-
-/// This method tries to decompose a type using provided decompose callback
-/// functions. If it is unable to do so, the original type is returned.
-void BufferizeTypeConverter::tryDecomposeType(Type type,
-                                              SmallVectorImpl<Type> &types) {
-  for (auto &conversion : decomposeTypeConversions)
-    if (conversion(type, types))
-      return;
-  types.push_back(type);
-}
-
 void mlir::populateBufferizeMaterializationLegality(ConversionTarget &target) {
   target.addLegalOp<TensorLoadOp, TensorToMemrefOp>();
 }
@@ -105,113 +83,3 @@ void mlir::populateEliminateBufferizeMaterializationsPatterns(
   patterns.insert<BufferizeTensorLoadOp, BufferizeTensorToMemrefOp>(
       typeConverter, context);
 }
-
-//===----------------------------------------------------------------------===//
-// BufferizeFuncOpConverter
-//===----------------------------------------------------------------------===//
-
-/// Performs the actual function signature rewriting step.
-LogicalResult BufferizeFuncOpConverter::matchAndRewrite(
-    mlir::FuncOp funcOp, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-  auto funcType = funcOp.getType();
-
-  // Convert function arguments using the provided TypeConverter.
-  TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
-  for (auto argType : llvm::enumerate(funcType.getInputs())) {
-    SmallVector<Type, 2> decomposedTypes, convertedTypes;
-    converter.tryDecomposeType(argType.value(), decomposedTypes);
-    converter.convertTypes(decomposedTypes, convertedTypes);
-    conversion.addInputs(argType.index(), convertedTypes);
-  }
-
-  // Convert the result types of the function.
-  SmallVector<Type, 2> newResultTypes;
-  newResultTypes.reserve(funcOp.getNumResults());
-  for (Type resultType : funcType.getResults()) {
-    SmallVector<Type, 2> originTypes;
-    converter.tryDecomposeType(resultType, originTypes);
-    for (auto origin : originTypes)
-      newResultTypes.push_back(converter.convertType(origin));
-  }
-
-  if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), converter,
-                                         &conversion)))
-    return failure();
-
-  // Update the signature of the function.
-  rewriter.updateRootInPlace(funcOp, [&] {
-    funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
-                                            newResultTypes));
-  });
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// BufferizeCallOpConverter
-//===----------------------------------------------------------------------===//
-
-/// Performs the actual rewriting step.
-LogicalResult BufferizeCallOpConverter::matchAndRewrite(
-    CallOp callOp, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-
-  Location loc = callOp.getLoc();
-  SmallVector<Value, 2> newOperands;
-
-  // TODO: if the CallOp references a FuncOp that only has a declaration (e.g.
-  // to an externally defined symbol like an external library calls), only
-  // convert if some special attribute is set.
-  // This will allow more control of interop across ABI boundaries.
-
-  // Create the operands list of the new `CallOp`. It unpacks the decomposable
-  // values if a decompose callback function has been provided by the user.
-  for (auto operand : operands)
-    converter.tryDecomposeValue(rewriter, loc, operand.getType(), operand,
-                                newOperands);
-
-  // Create the new result types for the new `CallOp` and track the indices in
-  // the new call op's results that correspond to the old call op's results.
-  SmallVector<Type, 2> newResultTypes;
-  SmallVector<SmallVector<int, 2>, 4> expandedResultIndices;
-  expandedResultIndices.resize(callOp.getNumResults());
-  for (auto result : llvm::enumerate(callOp.getResults())) {
-    SmallVector<Type, 2> originTypes;
-    converter.tryDecomposeType(result.value().getType(), originTypes);
-    auto &resultMapping = expandedResultIndices[result.index()];
-    for (Type origin : originTypes) {
-      Type converted = converter.convertType(origin);
-      newResultTypes.push_back(converted);
-      // The result value is not yet available. Its index is kept and it is
-      // replaced with the actual value of the new `CallOp` later.
-      resultMapping.push_back(newResultTypes.size() - 1);
-    }
-  }
-
-  CallOp newCallOp = rewriter.create<CallOp>(loc, callOp.getCallee(),
-                                             newResultTypes, newOperands);
-
-  // Build a replacing value for each result to replace its uses. If a result
-  // has multiple mapping values, it needs to be packed to a single value.
-  SmallVector<Value, 2> replacedValues;
-  replacedValues.reserve(callOp.getNumResults());
-  for (unsigned i = 0, e = callOp.getNumResults(); i < e; ++i) {
-    auto valuesToPack = llvm::to_vector<6>(
-        llvm::map_range(expandedResultIndices[i],
-                        [&](int i) { return newCallOp.getResult(i); }));
-    if (valuesToPack.empty()) {
-      // No replacement is required.
-      replacedValues.push_back(nullptr);
-    } else if (valuesToPack.size() == 1) {
-      replacedValues.push_back(valuesToPack.front());
-    } else {
-      // Values need to be packed using callback function. The same callback
-      // that is used for materializeArgumentConversion is used for packing.
-      Value packed = converter.materializeArgumentConversion(
-          rewriter, loc, callOp.getType(i), valuesToPack);
-      replacedValues.push_back(packed);
-    }
-  }
-  rewriter.replaceOp(callOp, replacedValues);
-  return success();
-}

diff  --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir
new file mode 100644
index 000000000000..c29bbdd2bd9e
--- /dev/null
+++ b/mlir/test/Transforms/decompose-call-graph-types.mlir
@@ -0,0 +1,116 @@
+// RUN: mlir-opt %s -split-input-file -test-decompose-call-graph-types | FileCheck %s
+
+// Test case: Most basic case of a 1:N decomposition, an identity function.
+
+// CHECK-LABEL:   func @identity(
+// CHECK-SAME:                   %[[ARG0:.*]]: i1,
+// CHECK-SAME:                   %[[ARG1:.*]]: i32) -> (i1, i32) {
+// CHECK:           %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32>
+// CHECK:           %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
+// CHECK:           %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
+// CHECK:           return %[[RET0]], %[[RET1]] : i1, i32
+func @identity(%arg0: tuple<i1, i32>) -> tuple<i1, i32> {
+  return %arg0 : tuple<i1, i32>
+}
+
+// -----
+
+// Test case: Ensure no materializations in the case of 1:1 decomposition.
+
+// CHECK-LABEL:   func @identity_1_to_1_no_materializations(
+// CHECK-SAME:                                              %[[ARG0:.*]]: i1) -> i1 {
+// CHECK:           return %[[ARG0]] : i1
+func @identity_1_to_1_no_materializations(%arg0: tuple<i1>) -> tuple<i1> {
+  return %arg0 : tuple<i1>
+}
+
+// -----
+
+// Test case: Type that needs to be recursively decomposed.
+
+// CHECK-LABEL:   func @recursive_decomposition(
+// CHECK-SAME:                                   %[[ARG0:.*]]: i1) -> i1 {
+// CHECK:           return %[[ARG0]] : i1
+func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tuple<tuple<i1>>> {
+  return %arg0 : tuple<tuple<tuple<i1>>>
+}
+
+// -----
+
+// Test case: Check decomposition of calls.
+
+// CHECK-LABEL:   func @callee(i1, i32) -> (i1, i32)
+func @callee(tuple<i1, i32>) -> tuple<i1, i32>
+
+// CHECK-LABEL:   func @caller(
+// CHECK-SAME:                 %[[ARG0:.*]]: i1,
+// CHECK-SAME:                 %[[ARG1:.*]]: i32) -> (i1, i32) {
+// CHECK:           %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32>
+// CHECK:           %[[CALL_ARG0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
+// CHECK:           %[[CALL_ARG1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
+// CHECK:           %[[DECOMPOSED:.*]]:2 = call @callee(%[[CALL_ARG0]], %[[CALL_ARG1]]) : (i1, i32) -> (i1, i32)
+// CHECK:           %[[CALL_RESULT_RECOMPOSED:.*]] = "test.make_tuple"(%[[DECOMPOSED]]#0, %[[DECOMPOSED]]#1) : (i1, i32) -> tuple<i1, i32>
+// CHECK:           %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
+// CHECK:           %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
+// CHECK:           return %[[RET0]], %[[RET1]] : i1, i32
+func @caller(%arg0: tuple<i1, i32>) -> tuple<i1, i32> {
+  %0 = call @callee(%arg0) : (tuple<i1, i32>) -> tuple<i1, i32>
+  return %0 : tuple<i1, i32>
+}
+
+// -----
+
+// Test case: Type that decomposes to nothing (that is, a 1:0 decomposition).
+
+// CHECK-LABEL:   func @callee()
+func @callee(tuple<>) -> tuple<>
+// CHECK-LABEL:   func @caller() {
+// CHECK:           call @callee() : () -> ()
+// CHECK:           return
+func @caller(%arg0: tuple<>) -> tuple<> {
+  %0 = call @callee(%arg0) : (tuple<>) -> (tuple<>)
+  return %0 : tuple<>
+}
+
+// -----
+
+// Test case: Ensure decompositions are inserted properly around results of
+// unconverted ops.
+
+// CHECK-LABEL:   func @unconverted_op_result() -> (i1, i32) {
+// CHECK:           %[[UNCONVERTED_VALUE:.*]] = "test.source"() : () -> tuple<i1, i32>
+// CHECK:           %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
+// CHECK:           %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
+// CHECK:           return %[[RET0]], %[[RET1]] : i1, i32
+func @unconverted_op_result() -> tuple<i1, i32> {
+  %0 = "test.source"() : () -> (tuple<i1, i32>)
+  return %0 : tuple<i1, i32>
+}
+
+// -----
+
+// Test case: Check mixed decomposed and non-decomposed args.
+// This makes sure to test the cases of 1:0, 1:1, and 1:N decompositions.
+
+// CHECK-LABEL:   func @callee(i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
+func @callee(tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6)
+
+// CHECK-LABEL:   func @caller(
+// CHECK-SAME:                 %[[I1:.*]]: i1,
+// CHECK-SAME:                 %[[I2:.*]]: i2,
+// CHECK-SAME:                 %[[I3:.*]]: i3,
+// CHECK-SAME:                 %[[I4:.*]]: i4,
+// CHECK-SAME:                 %[[I5:.*]]: i5,
+// CHECK-SAME:                 %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) {
+// CHECK:           %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[I4]], %[[I5]]) : (i4, i5) -> tuple<i4, i5>
+// CHECK:           %[[ARG_TUPLE_0:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 0 : i32} : (tuple<i4, i5>) -> i4
+// CHECK:           %[[ARG_TUPLE_1:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 1 : i32} : (tuple<i4, i5>) -> i5
+// CHECK:           %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[ARG_TUPLE_0]], %[[ARG_TUPLE_1]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
+// CHECK:           %[[RET_TUPLE:.*]] = "test.make_tuple"(%[[CALL]]#3, %[[CALL]]#4) : (i4, i5) -> tuple<i4, i5>
+// CHECK:           %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 0 : i32} : (tuple<i4, i5>) -> i4
+// CHECK:           %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 1 : i32} : (tuple<i4, i5>) -> i5
+// CHECK:           return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
+func @caller(%arg0: tuple<>, %arg1: i1, %arg2: tuple<i2>, %arg3: i3, %arg4: tuple<i4, i5>, %arg5: i6) -> (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) {
+  %0, %1, %2, %3, %4, %5 = call @callee(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6)
+  return %0, %1, %2, %3, %4, %5 : tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6
+}

diff  --git a/mlir/test/Transforms/finalizing-bufferize.mlir b/mlir/test/Transforms/finalizing-bufferize.mlir
deleted file mode 100644
index 2dc16317869e..000000000000
--- a/mlir/test/Transforms/finalizing-bufferize.mlir
+++ /dev/null
@@ -1,180 +0,0 @@
-// RUN: mlir-opt -test-finalizing-bufferize -split-input-file %s | FileCheck %s
-
-// CHECK-LABEL: func @void_function_signature_conversion
-func @void_function_signature_conversion(%arg0: tensor<4x8xf32>) {
-    return
-}
-// CHECK: ({{.*}}: memref<4x8xf32>)
-
-// -----
-
-// CHECK-LABEL: func @complex_signature_conversion
-func @complex_signature_conversion(
-  %arg0: tensor<5xf32>,
-  %arg1: memref<10xf32>,
-  %arg2: i1,
-  %arg3: f16) -> (
-    i1,
-    tensor<5xf32>,
-    memref<10xf32>,
-    memref<15xf32>,
-    f16) {
-  %0 = alloc() : memref<15xf32>
-  %1 = test.tensor_based in(%arg0 : tensor<5xf32>) -> tensor<5xf32>
-  return %arg2, %1, %arg1, %0, %arg3 :
-   i1, tensor<5xf32>, memref<10xf32>, memref<15xf32>, f16
-}
-//      CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>,
-// CHECK-SAME: %[[ARG2:.*]]: i1, %[[ARG3:.*]]: f16)
-// CHECK-SAME: (i1, memref<5xf32>, memref<10xf32>, memref<15xf32>, f16)
-//      CHECK: %[[FIRST_ALLOC:.*]] = alloc()
-//      CHECK: %[[TENSOR_ALLOC:.*]] = alloc()
-//      CHECK: return %[[ARG2]], %[[TENSOR_ALLOC]], %[[ARG1]], %[[FIRST_ALLOC]],
-// CHECK-SAME: %[[ARG3]]
-
-// -----
-
-// CHECK-LABEL: func @no_signature_conversion_is_needed
-func @no_signature_conversion_is_needed(%arg0: memref<4x8xf32>) {
-  return
-}
-// CHECK: ({{.*}}: memref<4x8xf32>)
-
-// -----
-
-// CHECK-LABEL: func @no_signature_conversion_is_needed
-func @no_signature_conversion_is_needed(%arg0: i1, %arg1: f16) -> (i1, f16){
-  return %arg0, %arg1 : i1, f16
-}
-// CHECK: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: f16) -> (i1, f16)
-// CHECK: return %[[ARG0]], %[[ARG1]]
-
-// -----
-
-// CHECK-LABEL: func @simple_signature_conversion
-func @simple_signature_conversion(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
-  return %arg0 : tensor<4x8xf32>
-}
-//      CHECK: (%[[ARG0:.*]]: [[TYPE:.*]]<[[RANK:.*]]>) -> [[TYPE]]<[[RANK]]>
-// CHECK-NEXT: return %[[ARG0]]
-
-// -----
-
-// CHECK-LABEL: func @func_with_unranked_arg_and_result
-func @func_with_unranked_arg_and_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  return %arg0 : tensor<*xf32>
-}
-// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32>
-// CHECK-NEXT: return [[ARG]] : memref<*xf32>
-
-// -----
-
-// CHECK-LABEL: func @func_and_block_signature_conversion
-func @func_and_block_signature_conversion(%arg0 : tensor<2xf32>, %cond : i1, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32>{
-    cond_br %cond, ^bb1, ^bb2
-  ^bb1:
-    br ^exit(%arg0 : tensor<2xf32>)
-  ^bb2:
-    br ^exit(%arg0 : tensor<2xf32>)
-  ^exit(%arg2: tensor<2xf32>):
-    return %arg1 : tensor<4x4xf32>
-}
-//      CHECK: (%[[ARG0:.*]]: [[ARG0_TYPE:.*]], %[[COND:.*]]: i1, %[[ARG1:.*]]: [[ARG1_TYPE:.*]]) -> [[RESULT_TYPE:.*]] {
-//      CHECK: br ^[[EXIT_BLOCK:.*]](%[[ARG0]] : [[ARG0_TYPE]])
-//      CHECK: br ^[[EXIT_BLOCK]](%[[ARG0]] : [[ARG0_TYPE]])
-//      CHECK: ^[[EXIT_BLOCK]](%{{.*}}: [[ARG0_TYPE]])
-// CHECK-NEXT:  return %[[ARG1]] : [[RESULT_TYPE]]
-
-// -----
-
-// CHECK-LABEL: func @callee
-func @callee(%arg1: tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>) {
-  %buff = alloc() : memref<2xf32>
-  return %arg1, %buff : tensor<5xf32>, memref<2xf32>
-}
-// CHECK: (%[[CALLEE_ARG:.*]]: memref<5xf32>) -> (memref<5xf32>, memref<2xf32>)
-// CHECK: %[[ALLOC:.*]] = alloc()
-// CHECK: return %[[CALLEE_ARG]], %[[ALLOC]]
-
-// CHECK-LABEL: func @caller
-func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
-  %x:2 = call @callee(%arg0) : (tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>)
-  %y:2 = call @callee(%x#0) : (tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>)
-  return %y#0 : tensor<5xf32>
-}
-// CHECK: (%[[CALLER_ARG:.*]]: memref<5xf32>) -> memref<5xf32>
-// CHECK: %[[X:.*]]:2 = call @callee(%[[CALLER_ARG]])
-// CHECK: %[[Y:.*]]:2 = call @callee(%[[X]]#0)
-// CHECK: return %[[Y]]#0
-
-// -----
-
-// Test case: Testing BufferizeCallOpConverter to see if it matches with the
-// signature of the new signature of the callee function when there are tuple
-// typed args and results. BufferizeTypeConverter is set to flatten tuple typed
-// arguments. The tuple typed values should be decomposed and composed using
-// get_tuple_element and make_tuple operations of test dialect. Tensor types are
-// converted to Memref. Memref typed function results remain as function
-// results.
-
-// CHECK-LABEL: func @callee
-func @callee(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>){
-  return %arg0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
-}
-// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>)
-// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>)
-// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]]  = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]
-
-// CHECK-LABEL: func @caller
-func @caller(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> tuple<tensor<2xf32>,i1, tensor<5xf32>>{
-  %x0 = call @callee(%arg0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
-  %y0 = call @callee(%x0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
-  return %y0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
-}
-// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>)
-// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>)
-// CHECK-NEXT: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]]  = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]])
-// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>)
-// CHECK-NEXT: %[[RESULT_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2)
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]]  = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]])
-// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>)
-// CHECK-NEXT: %[[RETURN_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2)
-// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[THIRD_ELEM:.*]]  = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 2 : i32}
-// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]
-
-// -----
-
-// Test case: Testing BufferizeFuncOpConverter and
-// BufferizeReturnOpConverter to see if the return operation matches with the
-// new function signature when there are tuple typed args and results.
-// BufferizeTypeConverter is set to flatten tuple typed arguments. The tuple
-// typed values should be decomposed and composed using get_tuple_element and
-// make_tuple operations of test dialect. Tensor types are converted to Memref.
-// Memref typed function results remain as function results.
-
-// CHECK-LABEL: func @decompose_tuple_typed_function_args_and_results
-func @decompose_tuple_typed_function_args_and_results(%arg0: tuple<i1,f32>, %arg1: tensor<10xf32>, %arg2: tuple<i1, tensor<5xf32>>) -> (tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>){
-  return %arg2, %arg1, %arg0 : tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>
-}
-// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32>
-// CHECK-SAME: (i1, memref<5xf32>, memref<10xf32>, i1, f32)
-// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]])
-// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]])
-// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]]  = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32}
-// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32}
-// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[SECOND_TUPLE_SECOND_ELEM]], %[[ARG2]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]]

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 69d45b570a3c..5dfe1e82c75a 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRTestTransforms
   TestAffineLoopParametricTiling.cpp
   TestExpandTanh.cpp
   TestCallGraph.cpp
+  TestDecomposeCallGraphTypes.cpp
   TestConstantFold.cpp
   TestConvVectorization.cpp
   TestConvertCallOp.cpp
@@ -10,7 +11,6 @@ add_mlir_library(MLIRTestTransforms
   TestConvertGPUKernelToHsaco.cpp
   TestDominance.cpp
   TestDynamicPipeline.cpp
-  TestFinalizingBufferize.cpp
   TestLoopFusion.cpp
   TestGpuMemoryPromotion.cpp
   TestGpuParallelLoopMapping.cpp

diff  --git a/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp
new file mode 100644
index 000000000000..26a0ae1f57b7
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp
@@ -0,0 +1,97 @@
+//===- TestDecomposeCallGraphTypes.cpp - Test CG type decomposition -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// A pass for testing call graph type decomposition.
+///
+/// This instantiates the patterns with a TypeConverter and ValueDecomposer
+/// that split tuple types into their respective element types.
+/// For example, `tuple<T1, T2, T3> --> T1, T2, T3`.
+struct TestDecomposeCallGraphTypes
+    : public PassWrapper<TestDecomposeCallGraphTypes, OperationPass<ModuleOp>> {
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<test::TestDialect>(); // pass creates test.* tuple ops
+  }
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    auto *context = &getContext();
+    TypeConverter typeConverter;
+    ConversionTarget target(*context);
+    ValueDecomposer decomposer;
+    OwningRewritePatternList patterns;
+
+    target.addLegalDialect<test::TestDialect>(); // incl. make_tuple/get_tuple_element
+
+    target.addDynamicallyLegalOp<ReturnOp>([&](ReturnOp op) {
+      return typeConverter.isLegal(op.getOperandTypes()); // no operand needs conversion
+    });
+    target.addDynamicallyLegalOp<CallOp>(
+        [&](CallOp op) { return typeConverter.isLegal(op); }); // operand + result types
+    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getType()); // function type only
+    });
+
+    typeConverter.addConversion([](Type type) { return type; }); // identity fallback
+    typeConverter.addConversion( // registered later, so tried first (LIFO order)
+        [](TupleType tupleType, SmallVectorImpl<Type> &types) {
+          tupleType.getFlattenedTypes(types); // nested tuples are flattened too
+          return success();
+        });
+
+    decomposer.addDecomposeValueConversion([](OpBuilder &builder, Location loc,
+                                              TupleType resultType, Value value,
+                                              SmallVectorImpl<Value> &values) {
+      for (unsigned i = 0, e = resultType.size(); i < e; ++i) { // one extract per top-level element
+        Value res = builder.create<test::GetTupleElementOp>(
+            loc, resultType.getType(i), value, builder.getI32IntegerAttr(i));
+        values.push_back(res);
+      }
+      return success();
+    });
+
+    typeConverter.addArgumentMaterialization( // packs decomposed values into a tuple
+        [](OpBuilder &builder, TupleType resultType, ValueRange inputs, // NOTE(review): resultType unused; tuple type rebuilt from inputs
+           Location loc) -> Optional<Value> {
+          if (inputs.size() == 1) // single input: return None, no make_tuple built
+            return llvm::None;
+          TypeRange TypeRange = inputs.getTypes(); // NOTE(review): local shadows the TypeRange type name
+          SmallVector<Type, 2> types(TypeRange.begin(), TypeRange.end());
+          TupleType tuple = TupleType::get(types, builder.getContext());
+          Value value = builder.create<test::MakeTupleOp>(loc, tuple, inputs);
+          return value;
+        });
+
+    populateDecomposeCallGraphTypesPatterns(context, typeConverter, decomposer,
+                                            patterns);
+
+    if (failed(applyPartialConversion(module, target, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+void registerTestDecomposeCallGraphTypes() { // called from mlir-opt's registerTestPasses()
+  PassRegistration<TestDecomposeCallGraphTypes> pass(
+      "test-decompose-call-graph-types",
+      "Decomposes types at call graph boundaries.");
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/test/lib/Transforms/TestFinalizingBufferize.cpp b/mlir/test/lib/Transforms/TestFinalizingBufferize.cpp
deleted file mode 100644
index b9001f3d52dd..000000000000
--- a/mlir/test/lib/Transforms/TestFinalizingBufferize.cpp
+++ /dev/null
@@ -1,167 +0,0 @@
-//===- TestFinalizingBufferize.cpp - Finalizing bufferization ---*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements a pass that exercises the functionality of finalizing
-// bufferizations.
-//
-//===----------------------------------------------------------------------===//
-
-#include "TestDialect.h"
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Transforms/Bufferize.h"
-
-using namespace mlir;
-
-namespace {
-/// This pass is a test for "finalizing" bufferize conversions.
-///
-/// A "finalizing" bufferize conversion is one that performs a "full" conversion
-/// and expects all tensors to be gone from the program. This in particular
-/// involves rewriting funcs (including block arguments of the contained
-/// region), calls, and returns. The unique property of finalizing bufferization
-/// passes is that they cannot be done via a local transformation with suitable
-/// materializations to ensure composability (as other bufferization passes do).
-/// For example, if a call is rewritten, the callee needs to be rewritten
-/// otherwise the IR will end up invalid. Thus, finalizing bufferization passes
-/// require an atomic change to the entire program (e.g. the whole module).
-///
-/// TODO: Split out BufferizeFinalizationPolicy from BufferizeTypeConverter.
-struct TestFinalizingBufferizePass
-    : mlir::PassWrapper<TestFinalizingBufferizePass, OperationPass<ModuleOp>> {
-
-  /// Converts tensor based test operations to buffer based ones using
-  /// bufferize.
-  class TensorBasedOpConverter
-      : public BufferizeOpConversionPattern<test::TensorBasedOp> {
-  public:
-    using BufferizeOpConversionPattern<
-        test::TensorBasedOp>::BufferizeOpConversionPattern;
-
-    LogicalResult
-    matchAndRewrite(test::TensorBasedOp op, ArrayRef<Value> operands,
-                    ConversionPatternRewriter &rewriter) const final {
-      mlir::test::TensorBasedOpAdaptor adaptor(
-          operands, op.getOperation()->getAttrDictionary());
-
-      // The input needs to be turned into a buffer first. Until then, bail out.
-      if (!adaptor.input().getType().isa<MemRefType>())
-        return failure();
-
-      Location loc = op.getLoc();
-
-      // Update the result type to a memref type.
-      auto type = op.getResult().getType().cast<ShapedType>();
-      if (!type.hasStaticShape())
-        return rewriter.notifyMatchFailure(
-            op, "dynamic shapes not currently supported");
-      auto memrefType = MemRefType::get(type.getShape(), type.getElementType());
-      Value newOutputBuffer = rewriter.create<AllocOp>(loc, memrefType);
-
-      // Generate a new test operation that works on buffers.
-      rewriter.create<mlir::test::BufferBasedOp>(loc,
-                                                 /*input=*/adaptor.input(),
-                                                 /*output=*/newOutputBuffer);
-
-      // Replace the results of the old op with the new output buffers.
-      rewriter.replaceOp(op, newOutputBuffer);
-      return success();
-    }
-  };
-
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<test::TestDialect>();
-  }
-
-  void runOnOperation() override {
-    MLIRContext &context = this->getContext();
-    ConversionTarget target(context);
-    BufferizeTypeConverter converter;
-
-    // Mark all Standard operations legal.
-    target.addLegalDialect<StandardOpsDialect>();
-    target.addLegalOp<test::MakeTupleOp>();
-    target.addLegalOp<test::GetTupleElementOp>();
-    target.addLegalOp<ModuleOp>();
-    target.addLegalOp<ModuleTerminatorOp>();
-
-    // Mark all Test operations illegal as long as they work on tensors.
-    auto isLegalOperation = [&](Operation *op) {
-      return converter.isLegal(op);
-    };
-    target.addDynamicallyLegalDialect<test::TestDialect>(isLegalOperation);
-
-    // Mark Standard Return operations illegal as long as one operand is tensor.
-    target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
-      return converter.isLegal(returnOp.getOperandTypes());
-    });
-
-    // Mark Standard Call Operation illegal as long as it operates on tensor.
-    target.addDynamicallyLegalOp<mlir::CallOp>(
-        [&](mlir::CallOp callOp) { return converter.isLegal(callOp); });
-
-    // Mark the function whose arguments are in tensor-type illegal.
-    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
-      return converter.isSignatureLegal(funcOp.getType()) &&
-             converter.isLegal(&funcOp.getBody());
-    });
-
-    converter.addDecomposeTypeConversion(
-        [](TupleType tupleType, SmallVectorImpl<Type> &types) {
-          tupleType.getFlattenedTypes(types);
-          return success();
-        });
-
-    converter.addArgumentMaterialization(
-        [](OpBuilder &builder, TupleType resultType, ValueRange inputs,
-           Location loc) -> Optional<Value> {
-          if (inputs.size() == 1)
-            return llvm::None;
-          TypeRange TypeRange = inputs.getTypes();
-          SmallVector<Type, 2> types(TypeRange.begin(), TypeRange.end());
-          TupleType tuple = TupleType::get(types, builder.getContext());
-          mlir::Value value =
-              builder.create<test::MakeTupleOp>(loc, tuple, inputs);
-          return value;
-        });
-
-    converter.addDecomposeValueConversion([](OpBuilder &builder, Location loc,
-                                             TupleType resultType, Value value,
-                                             SmallVectorImpl<Value> &values) {
-      for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
-        Value res = builder.create<test::GetTupleElementOp>(
-            loc, resultType.getType(i), value, builder.getI32IntegerAttr(i));
-        values.push_back(res);
-      }
-      return success();
-    });
-
-    OwningRewritePatternList patterns;
-    populateWithBufferizeOpConversionPatterns<ReturnOp, ReturnOp, test::CopyOp>(
-        &context, converter, patterns);
-    patterns.insert<TensorBasedOpConverter>(&context, converter);
-
-    if (failed(applyFullConversion(this->getOperation(), target,
-                                   std::move(patterns))))
-      this->signalPassFailure();
-  };
-};
-} // end anonymous namespace
-
-namespace mlir {
-namespace test {
-void registerTestFinalizingBufferizePass() {
-  PassRegistration<TestFinalizingBufferizePass>(
-      "test-finalizing-bufferize", "Tests finalizing bufferize conversions");
-}
-} // namespace test
-} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 58444e6a9501..0b4a66e37987 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -63,11 +63,11 @@ void registerTestConstantFold();
 void registerTestConvVectorization();
 void registerTestConvertGPUKernelToCubinPass();
 void registerTestConvertGPUKernelToHsacoPass();
+void registerTestDecomposeCallGraphTypes();
 void registerTestDialect(DialectRegistry &);
 void registerTestDominancePass();
 void registerTestDynamicPipelinePass();
 void registerTestExpandTanhPass();
-void registerTestFinalizingBufferizePass();
 void registerTestGpuParallelLoopMappingPass();
 void registerTestInterfaces();
 void registerTestLinalgCodegenStrategy();
@@ -130,10 +130,10 @@ void registerTestPasses() {
   test::registerTestConvertGPUKernelToHsacoPass();
 #endif
   test::registerTestConvVectorization();
+  test::registerTestDecomposeCallGraphTypes();
   test::registerTestDominancePass();
   test::registerTestDynamicPipelinePass();
   test::registerTestExpandTanhPass();
-  test::registerTestFinalizingBufferizePass();
   test::registerTestGpuParallelLoopMappingPass();
   test::registerTestInterfaces();
   test::registerTestLinalgCodegenStrategy();


        


More information about the Mlir-commits mailing list