[Mlir-commits] [mlir] 4214031 - [mlir] Introduce allowMemrefFunctionResults for the helper operation converters of buffer placement

Ehsan Toosi llvmlistbot at llvm.org
Mon Jun 8 00:31:04 PDT 2020


Author: Ehsan Toosi
Date: 2020-06-08T09:25:41+02:00
New Revision: 4214031d4337a6b04ae4c28119305182e37c45bc

URL: https://github.com/llvm/llvm-project/commit/4214031d4337a6b04ae4c28119305182e37c45bc
DIFF: https://github.com/llvm/llvm-project/commit/4214031d4337a6b04ae4c28119305182e37c45bc.diff

LOG: [mlir] Introduce allowMemrefFunctionResults for the helper operation converters of buffer placement

This parameter lets developers choose the function signature conversion they
need when preparing their functions for buffer placement. It is introduced for
BufferAssignmentFuncOpConverter, as well as for
BufferAssignmentReturnOpConverter and BufferAssignmentCallOpConverter, so that
return and call operations are adapted to the selected function signature
conversion. If the parameter is set, the returned buffers are also excluded
from deallocation by buffer placement.

Differential Revision: https://reviews.llvm.org/D81137

Added: 
    mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir

Modified: 
    mlir/include/mlir/Transforms/BufferPlacement.h
    mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
    mlir/lib/Transforms/BufferPlacement.cpp
    mlir/test/Transforms/buffer-placement-preparation.mlir
    mlir/test/lib/Transforms/TestBufferPlacement.cpp
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h
index 89cb4b043a19..547db487e454 100644
--- a/mlir/include/mlir/Transforms/BufferPlacement.h
+++ b/mlir/include/mlir/Transforms/BufferPlacement.h
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -88,12 +89,23 @@ class BufferAssignmentTypeConverter : public TypeConverter {
   static bool isConvertedMemref(Type type, Type before);
 };
 
-/// Converts the signature of the function using the type converter. It adds an
-/// extra argument for each function result type which is going to be a memref
-/// type after type conversion. The other function result types remain
-/// unchanged. `BufferAssignmentTypeConverter` is a helper `TypeConverter` for
-/// this purpose.
-class FunctionAndBlockSignatureConverter
+namespace detail {
+
+/// Converts the signature of the function based on whether the function is
+/// allowed to return memref typed results or not, as selected by the
+/// `allowMemrefFunctionResults` parameter. If this option is false, an extra
+/// function argument is added as an output buffer for each function result
+/// that becomes a memref type only after type conversion; the remaining
+/// results stay function results with their converted types. If
+/// `allowMemrefFunctionResults` is true, the types are converted in place.
+/// Any change in the function signature must also be applied to return and
+/// call operations. `BufferAssignmentReturnOpConverter` and
+/// `BufferAssignmentCallOpConverter` are two helper patterns that adapt the
+/// return and call operations to the new function signature. Furthermore,
+/// `BufferAssignmentTypeConverter` is a helper `TypeConverter` for converting
+/// tensor typed values to memref typed ones.
+template <bool allowMemrefFunctionResults>
+class BufferAssignmentFuncOpConverter
     : public BufferAssignmentOpConversionPattern<FuncOp> {
 public:
   using BufferAssignmentOpConversionPattern<
@@ -101,17 +113,55 @@ class FunctionAndBlockSignatureConverter
 
   /// Performs the actual signature rewriting step.
   LogicalResult
-  matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final;
+  matchAndRewrite(mlir::FuncOp funcOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (!converter)
+      return funcOp.emitError("The type converter has not been defined for "
+                              "BufferAssignmentFuncOpConverter");
+    auto funcType = funcOp.getType();
+
+    // Convert function arguments using the provided TypeConverter.
+    TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
+    for (auto argType : llvm::enumerate(funcType.getInputs()))
+      conversion.addInputs(argType.index(),
+                           converter->convertType(argType.value()));
+
+    // If allowMemrefFunctionResults is false and a function result type is not
+    // a memref but it would be a memref after type conversion, a new argument
+    // should be appended to the function arguments list for this result.
+    // Otherwise, it remains a function result with its converted type.
+    SmallVector<Type, 2> newResultTypes;
+    newResultTypes.reserve(funcOp.getNumResults());
+    for (Type resType : funcType.getResults()) {
+      Type convertedType = converter->convertType(resType);
+      if (!allowMemrefFunctionResults &&
+          BufferAssignmentTypeConverter::isConvertedMemref(convertedType,
+                                                           resType))
+        conversion.addInputs(convertedType);
+      else
+        newResultTypes.push_back(convertedType);
+    }
+
+    // Update the signature of the function.
+    rewriter.updateRootInPlace(funcOp, [&] {
+      funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
+                                              newResultTypes));
+      rewriter.applySignatureConversion(&funcOp.getBody(), conversion);
+    });
+    return success();
+  }
 };
 
 /// Rewrites the `ReturnOp` to conform with the changed function signature.
-/// Operands that correspond to return values that have been rewritten from
-/// tensor results to memref arguments are dropped. In their place, a
-/// corresponding copy operation from the operand to the new function argument
-/// is inserted.
+/// If allowMemrefFunctionResults is false, operands that correspond to return
+/// values and have been rewritten from illegal typed results to memref
+/// arguments are dropped. In their place, a corresponding copy operation from
+/// the operand to the output function argument is inserted. Otherwise, all
+/// type-converted operands are returned as they are.
+/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter,
+/// allowMemrefFunctionResults must have the same value for both.
 template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
-          typename CopyOpTy>
+          typename CopyOpTy, bool allowMemrefFunctionResults>
 class BufferAssignmentReturnOpConverter
     : public BufferAssignmentOpConversionPattern<ReturnOpSourceTy> {
 public:
@@ -122,6 +172,13 @@ class BufferAssignmentReturnOpConverter
   LogicalResult
   matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
+    // If the memref typed results can be returned as function results, the new
+    // `ReturnOp` should only return the type converted operands.
+    if (allowMemrefFunctionResults) {
+      rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, operands);
+      return success();
+    }
+
     // Split the operands by their kinds whether they are converted memref or
     // not.
     SmallVector<Value, 2> needCopyOperands, newOperands;
@@ -158,20 +215,99 @@ class BufferAssignmentReturnOpConverter
   }
 };
 
-/// Converts `CallOp` to match its operands and results with the
-/// the callee after rewriting the callee with
-/// FunctionAndBlockSignatureConverter.
+/// Rewrites the `CallOp` to match its operands and results with the signature
+/// of the callee after rewriting the callee with
+/// BufferAssignmentFuncOpConverter. If allowMemrefFunctionResults is false, an
+/// output buffer is allocated for each result that becomes a memref type only
+/// after type conversion, and the newly allocated buffer is passed through the
+/// operands list of the new `CallOp`.
+/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter,
+/// allowMemrefFunctionResults must have the same value for both.
+template <bool allowMemrefFunctionResults>
 class BufferAssignmentCallOpConverter
     : public BufferAssignmentOpConversionPattern<CallOp> {
 public:
   using BufferAssignmentOpConversionPattern<
       CallOp>::BufferAssignmentOpConversionPattern;
 
-  /// Performs the actual `CallOp` conversion step.
   LogicalResult
   matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final;
+                  ConversionPatternRewriter &rewriter) const final {
+    if (!converter)
+      return callOp.emitError("The type converter has not been defined for "
+                              "BufferAssignmentCallOpConverter");
+    Location loc = callOp.getLoc();
+
+    // If the memref typed results can be returned as function results, there is
+    // no need to create output buffers. It is only required to convert the type
+    // of operands and results in place for creating the new `CallOp`.
+    if (allowMemrefFunctionResults) {
+      SmallVector<Type, 2> resultTypes;
+      resultTypes.reserve(callOp.getNumResults());
+      for (Type type : callOp.getResultTypes())
+        resultTypes.push_back(converter->convertType(type));
+      rewriter.replaceOpWithNewOp<CallOp>(callOp, callOp.getCallee(),
+                                          resultTypes, operands);
+      return success();
+    }
+
+    SmallVector<Value, 2> newOperands, replacingValues;
+    SmallVector<Type, 2> newResultTypes;
+    unsigned numResults = callOp.getNumResults();
+    newOperands.reserve(numResults + operands.size());
+    newOperands.append(operands.begin(), operands.end());
+    newResultTypes.reserve(numResults);
+    replacingValues.reserve(numResults);
+
+    // For each memref result of `CallOp` which has not been a memref before
+    // the type conversion, a new buffer is allocated and passed to the operands
+    // list of the new `CallOp`. Otherwise, it remains as a caller result.
+    for (Value result : callOp.getResults()) {
+      Type currType = result.getType();
+      Type newType = converter->convertType(result.getType());
+      if (BufferAssignmentTypeConverter::isConvertedMemref(newType, currType)) {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.restoreInsertionPoint(bufferAssignment->computeAllocPosition(
+            result.dyn_cast<OpResult>()));
+        Value alloc =
+            rewriter.create<AllocOp>(loc, newType.dyn_cast<MemRefType>());
+        newOperands.push_back(alloc);
+        replacingValues.push_back(alloc);
+      } else {
+        newResultTypes.push_back(currType);
+
+        // No replacing is required.
+        replacingValues.push_back(nullptr);
+      }
+    }
+
+    // Creating the new `CallOp`.
+    rewriter.create<CallOp>(loc, callOp.getCallee(), newResultTypes,
+                            newOperands);
+
+    // Replacing the results of the old `CallOp`.
+    rewriter.replaceOp(callOp, replacingValues);
+    return success();
+  }
 };
+} // end namespace detail
+
+/// Populates `patterns` with the buffer assignment call, function, and return
+/// op converters, all using the same `allowMemrefFunctionResults` setting.
+template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
+          typename CopyOpTy, bool allowMemrefFunctionResults>
+static void populateWithBufferAssignmentOpConversionPatterns(
+    MLIRContext *context, BufferAssignmentPlacer *placer,
+    TypeConverter *converter, OwningRewritePatternList *patterns) {
+  // clang-format off
+  patterns->insert<
+    detail::BufferAssignmentCallOpConverter<allowMemrefFunctionResults>,
+    detail::BufferAssignmentFuncOpConverter<allowMemrefFunctionResults>,
+    detail::BufferAssignmentReturnOpConverter
+      <ReturnOpSourceTy, ReturnOpTargetTy, CopyOpTy, allowMemrefFunctionResults>
+  >(context, placer, converter);
+  // clang-format on
+}
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_BUFFERPLACEMENT_H

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
index c663eb6017e5..1f983e802eab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -20,9 +20,6 @@
 #include "mlir/Transforms/BufferPlacement.h"
 
 using namespace mlir;
-using ReturnOpConverter =
-    BufferAssignmentReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
-                                      linalg::CopyOp>;
 
 namespace {
 /// A pattern to convert Generic Linalg operations which work on tensors to
@@ -103,11 +100,11 @@ class GenericOpConverter
 static void populateConvertLinalgOnTensorsToBuffersPattern(
     MLIRContext *context, BufferAssignmentPlacer *placer,
     TypeConverter *converter, OwningRewritePatternList *patterns) {
-  // clang-format off
-  patterns->insert<FunctionAndBlockSignatureConverter,
-                   GenericOpConverter,
-                   ReturnOpConverter>(context, placer, converter);
-  // clang-format on
+  populateWithBufferAssignmentOpConversionPatterns<
+      mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp,
+      /*allowMemrefFunctionResults=*/false>(context, placer, converter,
+                                            patterns);
+  patterns->insert<GenericOpConverter>(context, placer, converter);
 }
 
 /// Converts Linalg operations that work on tensor-type operands or results to

diff  --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
index edbaf8e39c97..0bca5cf3e8b2 100644
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -49,8 +49,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Transforms/BufferPlacement.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Operation.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/Passes.h"
 
@@ -423,102 +421,6 @@ BufferAssignmentPlacer::computeAllocPosition(OpResult result) {
   return OpBuilder::InsertPoint(owner->getBlock(), Block::iterator(owner));
 }
 
-//===----------------------------------------------------------------------===//
-// FunctionAndBlockSignatureConverter
-//===----------------------------------------------------------------------===//
-
-// Performs the actual signature rewriting step.
-LogicalResult FunctionAndBlockSignatureConverter::matchAndRewrite(
-    FuncOp funcOp, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-  if (!converter) {
-    funcOp.emitError("The type converter has not been defined for "
-                     "FunctionAndBlockSignatureConverter");
-    return failure();
-  }
-  auto funcType = funcOp.getType();
-
-  // Convert function arguments using the provided TypeConverter.
-  TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
-  for (auto argType : llvm::enumerate(funcType.getInputs()))
-    conversion.addInputs(argType.index(),
-                         converter->convertType(argType.value()));
-
-  // If a function result type is not a memref but it would be a memref after
-  // type conversion, a new argument should be appended to the function
-  // arguments list for this result. Otherwise, it remains unchanged as a
-  // function result.
-  SmallVector<Type, 2> newResultTypes;
-  newResultTypes.reserve(funcOp.getNumResults());
-  for (Type resType : funcType.getResults()) {
-    Type convertedType = converter->convertType(resType);
-    if (BufferAssignmentTypeConverter::isConvertedMemref(convertedType,
-                                                         resType))
-      conversion.addInputs(convertedType);
-    else
-      newResultTypes.push_back(convertedType);
-  }
-
-  // Update the signature of the function.
-  rewriter.updateRootInPlace(funcOp, [&] {
-    funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
-                                            newResultTypes));
-    rewriter.applySignatureConversion(&funcOp.getBody(), conversion);
-  });
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// BufferAssignmentCallOpConverter
-//===----------------------------------------------------------------------===//
-
-// Performs `CallOp` conversion to match its operands and results with the
-// signature of the callee after rewriting the callee with
-// FunctionAndBlockSignatureConverter.
-LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
-    CallOp callOp, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-
-  Location loc = callOp.getLoc();
-  SmallVector<Value, 2> newOperands, replacingValues;
-  SmallVector<Type, 2> newResultTypes;
-  unsigned numResults = callOp.getNumResults();
-  newOperands.reserve(numResults + operands.size());
-  newOperands.append(operands.begin(), operands.end());
-  newResultTypes.reserve(numResults);
-  replacingValues.reserve(numResults);
-
-  // For each memref result of `CallOp` which has not been a memref before type
-  // conversion, a new buffer is allocated and passed to the operands list of
-  // the new `CallOp`. Otherwise, it remains as a caller result.
-  for (Value result : callOp.getResults()) {
-    Type currType = result.getType();
-    Type newType = converter->convertType(result.getType());
-    if (BufferAssignmentTypeConverter::isConvertedMemref(newType, currType)) {
-      OpBuilder::InsertionGuard guard(rewriter);
-      rewriter.restoreInsertionPoint(
-          bufferAssignment->computeAllocPosition(result.dyn_cast<OpResult>()));
-      Value alloc =
-          rewriter.create<AllocOp>(loc, newType.dyn_cast<MemRefType>());
-      newOperands.push_back(alloc);
-      replacingValues.push_back(alloc);
-    } else {
-      newResultTypes.push_back(currType);
-
-      // No replacing is required.
-      replacingValues.push_back(nullptr);
-    }
-  }
-
-  // Creating the new `CallOp`.
-  rewriter.create<CallOp>(loc, callOp.getCallee(), newResultTypes, newOperands);
-
-  // Replacing the results of the old `CallOp`.
-  rewriter.replaceOp(callOp, replacingValues);
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // BufferAssignmentTypeConverter
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
new file mode 100644
index 000000000000..adf6e30fe6c6
--- /dev/null
+++ b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt -test-buffer-placement-preparation-with-allowed-memref-results -split-input-file %s | FileCheck %s -dump-input-on-failure
+
+// Since allowMemrefFunctionResults is on for Buffer Placement in this test
+// pass, all tensor typed function results are converted to memref and remain
+// as function results. All memref typed function results will escape from the
+// deallocation phase of Buffer Placement.
+
+// CHECK-LABEL: func @void_function_signature_conversion
+func @void_function_signature_conversion(%arg0: tensor<4x8xf32>) {
+    return
+}
+// CHECK: ({{.*}}: memref<4x8xf32>)
+
+// -----
+
+#map0 = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: func @complex_signature_conversion
+func @complex_signature_conversion(%arg0: tensor<5xf32>, %arg1: memref<10xf32>, %arg2: i1, %arg3: f16) -> (i1, tensor<5xf32>, memref<10xf32>, memref<15xf32>, f16) {
+  %0 = alloc() : memref<15xf32>
+  %1 = linalg.generic {
+          args_in = 1 : i64,
+          args_out = 1 : i64,
+          indexing_maps = [#map0, #map0],
+          iterator_types = ["parallel"]
+        } %arg0 {
+        ^bb0(%gen1_arg0: f32):
+          %tmp1 = exp %gen1_arg0 : f32
+          linalg.yield %tmp1 : f32
+        }: tensor<5xf32> -> tensor<5xf32>
+  return %arg2, %1, %arg1, %0, %arg3 : i1, tensor<5xf32>, memref<10xf32>, memref<15xf32>, f16
+}
+//      CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>, %[[ARG2:.*]]: i1, %[[ARG3:.*]]: f16)
+// CHECK-SAME: (i1, memref<5xf32>, memref<10xf32>, memref<15xf32>, f16)
+//      CHECK: %[[FIRST_ALLOC:.*]] = alloc()
+//      CHECK: %[[LINALG_ALLOC:.*]] = alloc()
+//      CHECK: return %[[ARG2]], %[[LINALG_ALLOC]], %[[ARG1]], %[[FIRST_ALLOC]], %[[ARG3]]
+
+// -----
+
+// CHECK-LABEL: func @no_signature_conversion_is_needed
+func @no_signature_conversion_is_needed(%arg0: memref<4x8xf32>) {
+  return
+}
+// CHECK: ({{.*}}: memref<4x8xf32>)
+
+// -----
+
+// CHECK-LABEL: func @no_signature_conversion_is_needed
+func @no_signature_conversion_is_needed(%arg0: i1, %arg1: f16) -> (i1, f16){
+  return %arg0, %arg1 : i1, f16
+}
+// CHECK: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: f16) -> (i1, f16)
+// CHECK: return %[[ARG0]], %[[ARG1]]
+
+// -----
+
+// CHECK-LABEL: func @simple_signature_conversion
+func @simple_signature_conversion(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
+  return %arg0 : tensor<4x8xf32>
+}
+//      CHECK: (%[[ARG0:.*]]: [[TYPE:.*]]<[[RANK:.*]]>) -> [[TYPE]]<[[RANK]]>
+// CHECK-NEXT: return %[[ARG0]]
+
+// -----
+
+// CHECK-LABEL: func @func_and_block_signature_conversion
+func @func_and_block_signature_conversion(%arg0 : tensor<2xf32>, %cond : i1, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32>{
+    cond_br %cond, ^bb1, ^bb2
+  ^bb1:
+    br ^exit(%arg0 : tensor<2xf32>)
+  ^bb2:
+    br ^exit(%arg0 : tensor<2xf32>)
+  ^exit(%arg2: tensor<2xf32>):
+    return %arg1 : tensor<4x4xf32>
+}
+//      CHECK: (%[[ARG0:.*]]: [[ARG0_TYPE:.*]], %[[COND:.*]]: i1, %[[ARG1:.*]]: [[ARG1_TYPE:.*]]) -> [[RESULT_TYPE:.*]]
+//      CHECK: br ^[[EXIT_BLOCK:.*]](%[[ARG0]] : [[ARG0_TYPE]])
+//      CHECK: br ^[[EXIT_BLOCK]](%[[ARG0]] : [[ARG0_TYPE]])
+//      CHECK: ^[[EXIT_BLOCK]](%{{.*}}: [[ARG0_TYPE]])
+// CHECK-NEXT:  return %[[ARG1]]
+
+// -----
+
+// CHECK-LABEL: func @callee
+func @callee(%arg1: tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>) {
+  %buff = alloc() : memref<2xf32>
+  return %arg1, %buff : tensor<5xf32>, memref<2xf32>
+}
+// CHECK: (%[[CALLEE_ARG:.*]]: memref<5xf32>) -> (memref<5xf32>, memref<2xf32>)
+// CHECK: %[[ALLOC:.*]] = alloc()
+// CHECK: return %[[CALLEE_ARG]], %[[ALLOC]]
+
+// CHECK-LABEL: func @caller
+func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
+  %x:2 = call @callee(%arg0) : (tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>)
+  %y:2 = call @callee(%x#0) : (tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>)
+  return %y#0 : tensor<5xf32>
+}
+// CHECK: (%[[CALLER_ARG:.*]]: memref<5xf32>) -> memref<5xf32>
+// CHECK: %[[X:.*]]:2 = call @callee(%[[CALLER_ARG]])
+// CHECK: %[[Y:.*]]:2 = call @callee(%[[X]]#0)
+// CHECK: return %[[Y]]#0
+
+
+
+
+

diff  --git a/mlir/test/Transforms/buffer-placement-preparation.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir
index 5cde928e6da2..cae2829ead17 100644
--- a/mlir/test/Transforms/buffer-placement-preparation.mlir
+++ b/mlir/test/Transforms/buffer-placement-preparation.mlir
@@ -199,7 +199,7 @@ func @compute_allocs_position(%cond: i1, %arg0: tensor<2xf32>) -> tensor<2xf32>{
 // -----
 
 // Test case: Checking BufferAssignmentCallOpConverter and
-// FunctionAndBlockSignatureConverter and BufferAssignmentReturnOpConverter all
+// BufferAssignmentFuncOpConverter and BufferAssignmentReturnOpConverter all
 // together. The signature of `callee` after signature conversion would be:
 
 // func @callee(%arg0: memref<5xf32>,%arg1: memref<5xf32>) -> ()
@@ -246,7 +246,7 @@ func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
 // -----
 
 // Test case: Checking BufferAssignmentCallOpConverter and
-// FunctionAndBlockSignatureConverter and BufferAssignmentReturnOpConverter all
+// BufferAssignmentFuncOpConverter and BufferAssignmentReturnOpConverter all
 // together on functions that also have memref typed results. The signature of
 // `callee` after signature conversion would be:
 

diff  --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
index aee12b37a687..3d0cc290e9fc 100644
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -21,17 +21,22 @@
 using namespace mlir;
 
 namespace {
-/// This pass tests the computeAllocPosition helper method and two provided
-/// operation converters, FunctionAndBlockSignatureConverter and
-/// BufferAssignmentReturnOpConverter. Furthermore, this pass converts linalg
-/// operations on tensors to linalg operations on buffers to prepare them for
-/// the BufferPlacement pass that can be applied afterwards.
+/// This pass tests the computeAllocPosition helper method and buffer assignment
+/// operation converters. Furthermore, this pass converts linalg operations on
+/// tensors to linalg operations on buffers to prepare them for the
+/// BufferPlacement pass that can be applied afterwards.
+/// `allowMemrefFunctionResults` tells buffer placement to allow functions
+/// that have memref typed results. The buffer assignment operation converters
+/// are adapted accordingly. It also allows memref typed results to escape
+/// from deallocation.
+template <bool allowMemrefFunctionResults>
 struct TestBufferPlacementPreparationPass
-    : mlir::PassWrapper<TestBufferPlacementPreparationPass,
-                        OperationPass<ModuleOp>> {
+    : mlir::PassWrapper<
+          TestBufferPlacementPreparationPass<allowMemrefFunctionResults>,
+          OperationPass<ModuleOp>> {
 
-  /// Converts tensor-type generic linalg operations to memref ones using buffer
-  /// assignment.
+  /// Converts tensor-type generic linalg operations to memref ones using
+  /// buffer assignment.
   class GenericOpConverter
       : public BufferAssignmentOpConversionPattern<linalg::GenericOp> {
   public:
@@ -104,19 +109,14 @@ struct TestBufferPlacementPreparationPass
   void populateTensorLinalgToBufferLinalgConversionPattern(
       MLIRContext *context, BufferAssignmentPlacer *placer,
       TypeConverter *converter, OwningRewritePatternList *patterns) {
-    // clang-format off
-    patterns->insert<
-                   BufferAssignmentCallOpConverter,
-                   FunctionAndBlockSignatureConverter,
-                   GenericOpConverter,
-                   BufferAssignmentReturnOpConverter<
-                      ReturnOp, ReturnOp, linalg::CopyOp>
-    >(context, placer, converter);
-    // clang-format on
+    populateWithBufferAssignmentOpConversionPatterns<
+        mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp,
+        allowMemrefFunctionResults>(context, placer, converter, patterns);
+    patterns->insert<GenericOpConverter>(context, placer, converter);
   }
 
   void runOnOperation() override {
-    MLIRContext &context = getContext();
+    MLIRContext &context = this->getContext();
     ConversionTarget target(context);
     BufferAssignmentTypeConverter converter;
 
@@ -150,7 +150,7 @@ struct TestBufferPlacementPreparationPass
     });
 
     // Walk over all the functions to apply buffer assignment.
-    getOperation().walk([&](FuncOp function) -> WalkResult {
+    this->getOperation().walk([&](FuncOp function) -> WalkResult {
       OwningRewritePatternList patterns;
       BufferAssignmentPlacer placer(function);
       populateTensorLinalgToBufferLinalgConversionPattern(
@@ -165,9 +165,18 @@ struct TestBufferPlacementPreparationPass
 
 namespace mlir {
 void registerTestBufferPlacementPreparationPass() {
-  PassRegistration<TestBufferPlacementPreparationPass>(
+  PassRegistration<
+      TestBufferPlacementPreparationPass</*allowMemrefFunctionResults=*/false>>(
       "test-buffer-placement-preparation",
       "Tests buffer placement helper methods including its "
       "operation-conversion patterns");
 }
-} // end namespace mlir
\ No newline at end of file
+
+void registerTestPreparationPassWithAllowedMemrefResults() {
+  PassRegistration<
+      TestBufferPlacementPreparationPass</*allowMemrefFunctionResults=*/true>>(
+      "test-buffer-placement-preparation-with-allowed-memref-results",
+      "Tests the helper operation converters of buffer placement for allowing "
+      "functions to have memref typed results.");
+}
+} // end namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 2d286e112f99..067a2156c6fb 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -61,6 +61,7 @@ void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
 void registerTestOpaqueLoc();
 void registerTestParallelismDetection();
+void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestGpuParallelLoopMappingPass();
 void registerTestSCFUtilsPass();
 void registerTestVectorConversions();
@@ -133,6 +134,7 @@ void registerTestPasses() {
   registerTestMemRefStrideCalculation();
   registerTestOpaqueLoc();
   registerTestParallelismDetection();
+  registerTestPreparationPassWithAllowedMemrefResults();
   registerTestGpuParallelLoopMappingPass();
   registerTestSCFUtilsPass();
   registerTestVectorConversions();


        


More information about the Mlir-commits mailing list