[Mlir-commits] [mlir] 3468300 - [MLIR] Update the FunctionAndBlockSignatureConverter and NonVoidToVoidReturnOpConverter of Buffer Assignment

Ehsan Toosi llvmlistbot at llvm.org
Tue May 19 08:06:01 PDT 2020


Author: Ehsan Toosi
Date: 2020-05-19T17:04:59+02:00
New Revision: 346830051105a849d7fc3ceb246e65acbc0264ae

URL: https://github.com/llvm/llvm-project/commit/346830051105a849d7fc3ceb246e65acbc0264ae
DIFF: https://github.com/llvm/llvm-project/commit/346830051105a849d7fc3ceb246e65acbc0264ae.diff

LOG: [MLIR] Update the FunctionAndBlockSignatureConverter and NonVoidToVoidReturnOpConverter of Buffer Assignment

Making these two converters more generic. FunctionAndBlockSignatureConverter now
moves only memref results (after type conversion) to the function arguments and
keeps other legal function results unchanged. NonVoidToVoidReturnOpConverter is
renamed to NoBufferOperandsReturnOpConverter. It removes only the buffer
operands from the operands of the converted ReturnOp and inserts CopyOps to copy
each buffer to the corresponding target function argument.

Differential Revision: https://reviews.llvm.org/D79329

Added: 
    mlir/test/Transforms/buffer-placement-preparation.mlir

Modified: 
    mlir/include/mlir/Transforms/BufferPlacement.h
    mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
    mlir/lib/Transforms/BufferPlacement.cpp
    mlir/test/lib/Transforms/TestBufferPlacement.cpp

Removed: 
    mlir/test/Transforms/buffer-placement-prepration.mlir


################################################################################
diff  --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h
index 013a55f81f60..030b87599d06 100644
--- a/mlir/include/mlir/Transforms/BufferPlacement.h
+++ b/mlir/include/mlir/Transforms/BufferPlacement.h
@@ -76,12 +76,11 @@ class BufferAssignmentOpConversionPattern
   TypeConverter *converter;
 };
 
-/// This conversion adds an extra argument for each function result which makes
-/// the converted function a void function. A type converter must be provided
-/// for this conversion to convert a non-shaped type to memref.
-/// BufferAssignmentTypeConverter is an helper TypeConverter for this
-/// purpose. All the non-shaped type of the input function will be converted to
-/// memref.
+/// Converts the signature of the function using the type converter.
+/// It adds an extra argument for each illegally-typed function
+/// result to the function arguments. `BufferAssignmentTypeConverter`
+/// is a helper `TypeConverter` for this purpose. All the non-shaped types
+/// of the input function will be converted to memref.
 class FunctionAndBlockSignatureConverter
     : public BufferAssignmentOpConversionPattern<FuncOp> {
 public:
@@ -94,12 +93,12 @@ class FunctionAndBlockSignatureConverter
                   ConversionPatternRewriter &rewriter) const final;
 };
 
-/// This pattern converter transforms a non-void ReturnOpSourceTy into a void
-/// return of type ReturnOpTargetTy. It uses a copy operation of type CopyOpTy
-/// to copy the results to the output buffer.
+/// Converts the source `ReturnOp` to target `ReturnOp`, removes all
+/// the buffer operands from the operands list, and inserts `CopyOp`s
+/// for all buffer operands instead.
 template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
           typename CopyOpTy>
-class NonVoidToVoidReturnOpConverter
+class NoBufferOperandsReturnOpConverter
     : public BufferAssignmentOpConversionPattern<ReturnOpSourceTy> {
 public:
   using BufferAssignmentOpConversionPattern<
@@ -109,29 +108,38 @@ class NonVoidToVoidReturnOpConverter
   LogicalResult
   matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    unsigned numReturnValues = returnOp.getNumOperands();
     Block &entryBlock = returnOp.getParentRegion()->front();
     unsigned numFuncArgs = entryBlock.getNumArguments();
     Location loc = returnOp.getLoc();
 
-    // Find the corresponding output buffer for each operand.
-    assert(numReturnValues <= numFuncArgs &&
-           "The number of operands of return operation is more than the "
-           "number of function argument.");
-    unsigned firstReturnParameter = numFuncArgs - numReturnValues;
-    for (auto operand : llvm::enumerate(operands)) {
-      unsigned returnArgNumber = firstReturnParameter + operand.index();
-      BlockArgument dstBuffer = entryBlock.getArgument(returnArgNumber);
-      if (dstBuffer == operand.value())
-        continue;
-
-      // Insert the copy operation to copy before the return.
-      rewriter.setInsertionPoint(returnOp);
-      rewriter.create<CopyOpTy>(loc, operand.value(),
-                                entryBlock.getArgument(returnArgNumber));
-    }
-    // Insert the new target return operation.
-    rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp);
+    // The target `ReturnOp` should not contain any memref operands.
+    SmallVector<Value, 2> newOperands(operands.begin(), operands.end());
+    llvm::erase_if(newOperands, [](Value operand) {
+      return operand.getType().isa<MemRefType>();
+    });
+
+    // Find the index of the first destination buffer.
+    unsigned numBufferOperands = operands.size() - newOperands.size();
+    unsigned destArgNum = numFuncArgs - numBufferOperands;
+
+    rewriter.setInsertionPoint(returnOp);
+    // Find the corresponding destination buffer for each memref operand.
+    for (Value operand : operands)
+      if (operand.getType().isa<MemRefType>()) {
+        assert(destArgNum < numFuncArgs &&
+               "The number of operands of return operation is more than the "
+               "number of function argument.");
+
+        // For each memref type operand of the source `ReturnOp`, a new `CopyOp`
+        // is inserted that copies the buffer content from the operand to the
+        // target.
+        rewriter.create<CopyOpTy>(loc, operand,
+                                  entryBlock.getArgument(destArgNum));
+        ++destArgNum;
+      }
+
+    // Insert the new target Return operation.
+    rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, newOperands);
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
index 93501011a416..9b5855dff0ce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -21,8 +21,8 @@
 
 using namespace mlir;
 using ReturnOpConverter =
-    NonVoidToVoidReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
-                                   linalg::CopyOp>;
+    NoBufferOperandsReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
+                                      linalg::CopyOp>;
 
 namespace {
 /// A pattern to convert Generic Linalg operations which work on tensors to
@@ -132,30 +132,6 @@ struct ConvertLinalgOnTensorsToBuffers
         Optional<ConversionTarget::DynamicLegalityCallbackFn>(
             isLegalOperation));
 
-    // TODO: Considering the following dynamic legality checks, the current
-    // implementation of FunctionAndBlockSignatureConverter of Buffer Assignment
-    // will convert the function signature incorrectly. This converter moves
-    // all the return values of the function to the input argument list without
-    // considering the return value types and creates a void function. However,
-    // the NonVoidToVoidReturnOpConverter doesn't change the return operation if
-    // its operands are not tensors. The following example leaves the IR in a
-    // broken state.
-    //
-    // @function(%arg0: f32, %arg1: tensor<4xf32>) -> (f32, f32) {
-    //    %0 = mulf %arg0, %arg0 : f32
-    //    return %0, %0 : f32, f32
-    // }
-    //
-    // broken IR after conversion:
-    //
-    // func @function(%arg0: f32, %arg1: memref<4xf32>, f32, f32) {
-    //    %0 = mulf %arg0, %arg0 : f32
-    //    return %0, %0 : f32, f32
-    // }
-    //
-    // This issue must be fixed in FunctionAndBlockSignatureConverter and
-    // NonVoidToVoidReturnOpConverter.
-
     // Mark Standard Return operations illegal as long as one operand is tensor.
     target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
       return llvm::none_of(returnOp.getOperandTypes(), isIllegalType);

diff  --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
index 24c228ec0657..cd0641c1ac32 100644
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -43,7 +43,8 @@
 // The current implementation does not support loops and the resulting code will
 // be invalid with respect to program semantics. The only thing that is
 // currently missing is a high-level loop analysis that allows us to move allocs
-// and deallocs outside of the loop blocks.
+// and deallocs outside of the loop blocks. Furthermore, it doesn't also accept
+// functions which return buffers already.
 //
 //===----------------------------------------------------------------------===//
 
@@ -429,19 +430,39 @@ LogicalResult FunctionAndBlockSignatureConverter::matchAndRewrite(
                      "FunctionAndBlockSignatureConverter");
     return failure();
   }
-  // Converting shaped type arguments to memref type.
   auto funcType = funcOp.getType();
+  TypeRange resultTypes = funcType.getResults();
+  if (llvm::any_of(resultTypes,
+                   [](Type type) { return type.isa<MemRefType>(); }))
+    return funcOp.emitError("BufferAssignmentPlacer doesn't currently support "
+                            "functions which return memref typed values");
+
+  // Convert function arguments using the provided TypeConverter.
   TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
   for (auto argType : llvm::enumerate(funcType.getInputs()))
     conversion.addInputs(argType.index(),
                          converter->convertType(argType.value()));
-  // Adding function results to the arguments of the converted function as
-  // memref type. The converted function will be a void function.
-  for (Type resType : funcType.getResults())
-    conversion.addInputs(converter->convertType((resType)));
+
+  // Adding a function argument for each function result which is going to be a
+  // memref type after type conversion.
+  SmallVector<Type, 2> newResultTypes;
+  newResultTypes.reserve(funcOp.getNumResults());
+  for (Type resType : resultTypes) {
+    Type convertedType = converter->convertType(resType);
+
+    // If the result type is memref after the type conversion, a new argument
+    // should be appended to the function arguments list for this result.
+    // Otherwise, it remains unchanged as a function result.
+    if (convertedType.isa<MemRefType>())
+      conversion.addInputs(convertedType);
+    else
+      newResultTypes.push_back(convertedType);
+  }
+
+  // Update the signature of the function.
   rewriter.updateRootInPlace(funcOp, [&] {
-    funcOp.setType(
-        rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None));
+    funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
+                                            newResultTypes));
     rewriter.applySignatureConversion(&funcOp.getBody(), conversion);
   });
   return success();

diff  --git a/mlir/test/Transforms/buffer-placement-prepration.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir
similarity index 80%
rename from mlir/test/Transforms/buffer-placement-prepration.mlir
rename to mlir/test/Transforms/buffer-placement-preparation.mlir
index 76212538aa3f..ef7a2e328da5 100644
--- a/mlir/test/Transforms/buffer-placement-prepration.mlir
+++ b/mlir/test/Transforms/buffer-placement-preparation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-buffer-placement-preparation -split-input-file %s | FileCheck %s -dump-input-on-failure
+// RUN: mlir-opt -test-buffer-placement-preparation -split-input-file -verify-diagnostics %s | FileCheck %s -dump-input-on-failure
 
 // CHECK-LABEL: func @func_signature_conversion
 func @func_signature_conversion(%arg0: tensor<4x8xf32>) {
@@ -8,6 +8,44 @@ func @func_signature_conversion(%arg0: tensor<4x8xf32>) {
 
 // -----
 
+// expected-error @below {{BufferAssignmentPlacer doesn't currently support functions which return memref typed values}}
+// expected-error @below {{failed to legalize operation 'func'}}
+func @memref_in_function_results(%arg0: tensor<4x8xf32>) -> (tensor<4x8xf32>, memref<5xf32>) {
+  %0 = alloc() : memref<5xf32>
+  return %arg0, %0 : tensor<4x8xf32>, memref<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @no_signature_conversion_is_needed
+func @no_signature_conversion_is_needed(%arg0: memref<4x8xf32>) {
+  return
+}
+// CHECK: ({{.*}}: memref<4x8xf32>) {
+
+// -----
+
+// CHECK-LABEL: func @no_signature_conversion_is_needed
+func @no_signature_conversion_is_needed(%arg0: i1, %arg1: f16) -> (i1, f16){
+  return %arg0, %arg1 : i1, f16
+}
+// CHECK: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: f16) -> (i1, f16)
+// CHECK: return %[[ARG0]], %[[ARG1]]
+
+// -----
+
+// CHECK-LABEL: func @complex_signature_conversion
+func @complex_signature_conversion(%arg0: tensor<4x8xf32>, %arg1: i1, %arg2: tensor<5x5xf64>,%arg3: f16) -> (i1, tensor<5x5xf64>, f16, tensor<4x8xf32>) {
+    return %arg1, %arg2, %arg3, %arg0 : i1, tensor<5x5xf64>, f16, tensor<4x8xf32>
+}
+//      CHECK: (%[[ARG0:.*]]: memref<4x8xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5x5xf64>, %[[ARG3:.*]]: f16,
+// CHECK-SAME: %[[RESULT1:.*]]: memref<5x5xf64>, %[[RESULT2:.*]]: memref<4x8xf32>) -> (i1, f16) {
+// CHECK-NEXT: linalg.copy(%[[ARG2]], %[[RESULT1]])
+// CHECK-NEXT: linalg.copy(%[[ARG0]], %[[RESULT2]])
+// CHECK-NEXT: return %[[ARG1]], %[[ARG3]]
+
+// -----
+
 // CHECK-LABEL: func @non_void_to_void_return_op_converter
 func @non_void_to_void_return_op_converter(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
   return %arg0 : tensor<4x8xf32>

diff  --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
index 03c6a2a72d50..2d781e64cdfa 100644
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -23,7 +23,7 @@ using namespace mlir;
 namespace {
 /// This pass tests the computeAllocPosition helper method and two provided
 /// operation converters, FunctionAndBlockSignatureConverter and
-/// NonVoidToVoidReturnOpConverter. Furthermore, this pass converts linalg
+/// NoBufferOperandsReturnOpConverter. Furthermore, this pass converts linalg
 /// operations on tensors to linalg operations on buffers to prepare them for
 /// the BufferPlacement pass that can be applied afterwards.
 struct TestBufferPlacementPreparationPass
@@ -82,7 +82,6 @@ struct TestBufferPlacementPreparationPass
         auto type = result.getType().cast<ShapedType>();
         entryBlock.addArgument(type.getElementType());
       }
-
       rewriter.eraseOp(op);
       return success();
     }
@@ -95,7 +94,7 @@ struct TestBufferPlacementPreparationPass
     patterns->insert<
                    FunctionAndBlockSignatureConverter,
                    GenericOpConverter,
-                   NonVoidToVoidReturnOpConverter<
+                   NoBufferOperandsReturnOpConverter<
                       ReturnOp, ReturnOp, linalg::CopyOp>
     >(context, placer, converter);
     // clang-format on
@@ -105,8 +104,9 @@ struct TestBufferPlacementPreparationPass
     auto &context = getContext();
     ConversionTarget target(context);
     BufferAssignmentTypeConverter converter;
-    // Make all linalg operations illegal as long as they work on tensors.
     target.addLegalDialect<StandardOpsDialect>();
+
+    // Make all linalg operations illegal as long as they work on tensors.
     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(
         Optional<ConversionTarget::DynamicLegalityCallbackFn>(
             [&](Operation *op) {
@@ -117,9 +117,12 @@ struct TestBufferPlacementPreparationPass
                      llvm::none_of(op->getResultTypes(), isIllegalType);
             }));
 
-    // Mark return operations illegal as long as they return values.
-    target.addDynamicallyLegalOp<mlir::ReturnOp>(
-        [](mlir::ReturnOp returnOp) { return returnOp.getNumOperands() == 0; });
+    // Mark std.ReturnOp illegal as long as an operand is tensor or buffer.
+    target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
+      return llvm::none_of(returnOp.getOperandTypes(), [&](Type type) {
+        return type.isa<MemRefType>() || !converter.isLegal(type);
+      });
+    });
 
     // Mark the function whose arguments are in tensor-type illegal.
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {


        


More information about the Mlir-commits mailing list