[Mlir-commits] [mlir] 4dd5f79 - [mlir][bufferize] Add argument materialization for bufferization

Stephan Herhut llvmlistbot at llvm.org
Thu Nov 26 04:44:34 PST 2020


Author: Stephan Herhut
Date: 2020-11-26T13:43:44+01:00
New Revision: 4dd5f79f07022dbbff547f4aff13b27134331215

URL: https://github.com/llvm/llvm-project/commit/4dd5f79f07022dbbff547f4aff13b27134331215
DIFF: https://github.com/llvm/llvm-project/commit/4dd5f79f07022dbbff547f4aff13b27134331215.diff

LOG: [mlir][bufferize] Add argument materialization for bufferization

This enables partial bufferization that includes function signatures. To test this, this
change also makes the func-bufferize pass partial and adds a dedicated finalizing-bufferize pass.

Differential Revision: https://reviews.llvm.org/D92032

Added: 
    mlir/test/Dialect/Standard/func-bufferize-partial.mlir

Modified: 
    mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
    mlir/include/mlir/Transforms/Passes.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
    mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
    mlir/lib/Transforms/Bufferize.cpp
    mlir/test/Dialect/Standard/func-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
index 5a1bc7b9716e..55da3af88758 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
@@ -26,6 +26,13 @@ void populateCallOpTypeConversionPattern(OwningRewritePatternList &patterns,
                                           MLIRContext *ctx,
                                           TypeConverter &converter);
 
+/// Add patterns to the given pattern list that rewrite branch operations and
+/// `return` to use operands that have been legalized by the conversion
+/// framework. Branch operations are only rewritten if they implement the
+/// BranchOpInterface. Only needed for partial conversions.
+void populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx,
+    TypeConverter &converter);
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
index 3be398fecb0c..9623dd14a296 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -25,28 +25,26 @@ def StdExpandOps : FunctionPass<"std-expand"> {
 def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
   let summary = "Bufferize func/call/return ops";
   let description = [{
-    A finalizing bufferize pass that bufferizes std.func and std.call ops.
+    A bufferize pass that bufferizes std.func, std.call, and std.return ops.
 
     Because this pass updates std.func ops, it must be a module pass. It is
     useful to keep this pass separate from other bufferizations so that the
     other ones can be run at function-level in parallel.
 
-    This pass must be done atomically for two reasons:
-    1. This pass changes func op signatures, which requires atomically updating
-       calls as well throughout the entire module.
-    2. This pass changes the type of block arguments, which requires that all
-       successor arguments of predecessors be converted. Terminators are not
-       a closed universe (and need not implement BranchOpInterface), and so we
-       cannot in general rewrite them.
+    This pass must be done atomically because it changes func op signatures,
+    which requires atomically updating calls as well throughout the entire
+    module.
 
-    Note, because this is a "finalizing" bufferize step, it can create
-    invalid IR because it will not create materializations. To avoid this
-    situation, the pass must only be run when the only SSA values of
-    tensor type are:
-    - block arguments
-    - the result of tensor_load
-    Other values of tensor type should be eliminated by earlier
-    bufferization passes.
+    This pass also changes the type of block arguments, which requires that all
+    successor arguments of predecessors be converted. This is achieved by
+    rewriting terminators based on the information provided by the
+    `BranchOpInterface`.
+    As this pass rewrites function operations, it also rewrites the
+    corresponding return operations. Other return-like operations that
+    implement the `ReturnLike` trait are not rewritten in general, as they
+    require that the corresponding parent operation is also rewritten.
+    Finally, this pass fails for unknown terminators, as we cannot decide
+    whether they need rewriting.
   }];
   let constructor = "mlir::createFuncBufferizePass()";
 }

diff  --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 2e3437a46611..77d98ce79cca 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -46,6 +46,10 @@ std::unique_ptr<Pass>
 createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes = 1024,
                                 unsigned bitwidthOfIndexType = 64);
 
+/// Creates a pass that finalizes a partial bufferization by removing remaining
+/// `tensor_load` and `tensor_to_memref` operations.
+std::unique_ptr<FunctionPass> createFinalizingBufferizePass();
+
 /// Creates a pass that converts memref function results to out-params.
 std::unique_ptr<Pass> createBufferResultsToOutParamsPass();
 

diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index da4ca24db499..29fe43fc0169 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -290,6 +290,22 @@ def Inliner : Pass<"inline"> {
   ];
 }
 
+def FinalizingBufferize : FunctionPass<"finalizing-bufferize"> {
+  let summary = "Finalize a partial bufferization";
+  let description = [{
+    A bufferize pass that finalizes a partial bufferization by removing
+    remaining `tensor_load` and `tensor_to_memref` operations.
+
+    The removal of those operations is only possible if the operations only
+    exist in pairs, i.e., all uses of `tensor_load` operations are
+    `tensor_to_memref` operations.
+
+    This pass will fail if not all operations can be removed or if any operation
+    with tensor-typed results remains.
+  }];
+  let constructor = "mlir::createFinalizingBufferizePass()";
+}
+
 def LocationSnapshot : Pass<"snapshot-op-locations"> {
   let summary = "Generate new locations from the current IR";
   let description = [{

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
index 4aadb72e6368..1aace4517f71 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
@@ -21,6 +21,8 @@ using namespace mlir;
 
 namespace {
 struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
+  using FuncBufferizeBase<FuncBufferizePass>::FuncBufferizeBase;
+
   void runOnOperation() override {
     auto module = getOperation();
     auto *context = &getContext();
@@ -35,14 +37,42 @@ struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
              typeConverter.isLegal(&op.getBody());
     });
     populateCallOpTypeConversionPattern(patterns, context, typeConverter);
-    populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
-                                                       patterns);
-    target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
+    target.addDynamicallyLegalOp<CallOp>(
+        [&](CallOp op) { return typeConverter.isLegal(op); });
 
-    // If all result types are legal, and all block arguments are legal (ensured
-    // by func conversion above), then all types in the program are legal.
+    populateBranchOpInterfaceAndReturnOpTypeConversionPattern(patterns, context,
+                                                              typeConverter);
+    target.addLegalOp<ModuleOp, ModuleTerminatorOp, TensorLoadOp,
+                      TensorToMemrefOp>();
+    target.addDynamicallyLegalOp<ReturnOp>(
+        [&](ReturnOp op) { return typeConverter.isLegal(op); });
+    // Mark terminators as legal if they have the ReturnLike trait or
+    // implement the BranchOpInterface and have valid types. If they do not
+    // implement the trait or interface, mark them as illegal no matter what.
     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
-      return typeConverter.isLegal(op->getResultTypes());
+      // If it is not a terminator, ignore it.
+      if (op->isKnownNonTerminator())
+        return true;
+      // If it is not the last operation in the block, also ignore it. We do
+      // this to handle unknown operations, as well.
+      Block *block = op->getBlock();
+      if (!block || &block->back() != op)
+        return true;
+      // ReturnLike operations have to be legalized with their parent. For
+      // return this is handled, for other ops they remain as is.
+      if (op->hasTrait<OpTrait::ReturnLike>())
+        return true;
+      // All successor operands of branch like operations must be rewritten.
+      if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
+        for (int p = 0, e = op->getNumSuccessors(); p < e; ++p) {
+          auto successorOperands = branchOp.getSuccessorOperands(p);
+          if (successorOperands.hasValue() &&
+              !typeConverter.isLegal(successorOperands.getValue().getTypes()))
+            return false;
+        }
+        return true;
+      }
+      return false;
     });
 
     if (failed(applyFullConversion(module, target, std::move(patterns))))

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
index 9d8fceb16db3..07d7c59e192b 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
@@ -13,21 +13,19 @@
 using namespace mlir;
 
 namespace {
-// Converts the operand and result types of the Standard's CallOp, used together
-// with the FuncOpSignatureConversion.
+/// Converts the operand and result types of the Standard's CallOp, used
+/// together with the FuncOpSignatureConversion.
 struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
-  CallOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
-      : OpConversionPattern(ctx), converter(converter) {}
+  using OpConversionPattern<CallOp>::OpConversionPattern;
 
   /// Hook for derived classes to implement combined matching and rewriting.
   LogicalResult
   matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    FunctionType type = callOp.getCalleeType();
-
     // Convert the original function results.
     SmallVector<Type, 1> convertedResults;
-    if (failed(converter.convertTypes(type.getResults(), convertedResults)))
+    if (failed(typeConverter->convertTypes(callOp.getResultTypes(),
+                                           convertedResults)))
       return failure();
 
     // Substitute with the new result types from the corresponding FuncType
@@ -36,14 +34,77 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
                                         convertedResults, operands);
     return success();
   }
-
-  /// The type converter to use when rewriting the signature.
-  TypeConverter &converter;
 };
 } // end anonymous namespace
 
 void mlir::populateCallOpTypeConversionPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx,
     TypeConverter &converter) {
-  patterns.insert<CallOpSignatureConversion>(ctx, converter);
+  patterns.insert<CallOpSignatureConversion>(converter, ctx);
+}
+
+namespace {
+/// Only needed to support partial conversion of functions where this pattern
+/// ensures that the branch operation arguments match up with the successor
+/// block arguments.
+class BranchOpInterfaceTypeConversion : public ConversionPattern {
+public:
+  BranchOpInterfaceTypeConversion(TypeConverter &typeConverter,
+                                  MLIRContext *ctx)
+      : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto branchOp = dyn_cast<BranchOpInterface>(op);
+    if (!branchOp)
+      return failure();
+
+    // For a branch operation, only some operands go to the target blocks, so
+    // only rewrite those.
+    SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
+    for (int succIdx = 0, succEnd = op->getNumSuccessors();
+         succIdx < succEnd; ++succIdx) {
+      auto successorOperands = branchOp.getSuccessorOperands(succIdx);
+      if (!successorOperands)
+        continue;
+      for (int idx = successorOperands->getBeginOperandIndex(),
+               eidx = idx + successorOperands->size();
+           idx < eidx; ++idx) {
+        newOperands[idx] = operands[idx];
+      }
+    }
+    rewriter.updateRootInPlace(
+        op, [&]() { op->setOperands(newOperands); });
+    return success();
+  }
+};
+} // end anonymous namespace
+
+namespace {
+/// Only needed to support partial conversion of functions where this pattern
+/// ensures that the return operation operands match up with the (converted)
+/// result types of the parent function.
+class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
+public:
+  using OpConversionPattern<ReturnOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // For a return, all operands go to the results of the parent, so
+    // rewrite them all.
+    Operation *operation = op.getOperation();
+    rewriter.updateRootInPlace(
+        op, [operands, operation]() { operation->setOperands(operands); });
+    return success();
+  }
+};
+} // end anonymous namespace
+
+void mlir::populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx,
+    TypeConverter &typeConverter) {
+  patterns.insert<BranchOpInterfaceTypeConversion, ReturnOpTypeConversion>(
+      typeConverter, ctx);
 }

diff  --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp
index ba622335a396..1811ac8bdfbc 100644
--- a/mlir/lib/Transforms/Bufferize.cpp
+++ b/mlir/lib/Transforms/Bufferize.cpp
@@ -7,7 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Transforms/Bufferize.h"
+#include "PassDetail.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/Transforms/Passes.h"
 
 using namespace mlir;
 
@@ -15,6 +17,13 @@ using namespace mlir;
 // BufferizeTypeConverter
 //===----------------------------------------------------------------------===//
 
+static Value materializeTensorLoad(OpBuilder &builder, TensorType type,
+                                   ValueRange inputs, Location loc) {
+  assert(inputs.size() == 1);
+  assert(inputs[0].getType().isa<BaseMemRefType>());
+  return builder.create<TensorLoadOp>(loc, type, inputs[0]);
+}
+
 /// Registers conversions into BufferizeTypeConverter
 BufferizeTypeConverter::BufferizeTypeConverter() {
   // Keep all types unchanged.
@@ -27,12 +36,8 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
   addConversion([](UnrankedTensorType type) -> Type {
     return UnrankedMemRefType::get(type.getElementType(), 0);
   });
-  addSourceMaterialization([](OpBuilder &builder, TensorType type,
-                              ValueRange inputs, Location loc) -> Value {
-    assert(inputs.size() == 1);
-    assert(inputs[0].getType().isa<BaseMemRefType>());
-    return builder.create<TensorLoadOp>(loc, type, inputs[0]);
-  });
+  addArgumentMaterialization(materializeTensorLoad);
+  addSourceMaterialization(materializeTensorLoad);
   addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                               ValueRange inputs, Location loc) -> Value {
     assert(inputs.size() == 1);
@@ -83,3 +88,37 @@ void mlir::populateEliminateBufferizeMaterializationsPatterns(
   patterns.insert<BufferizeTensorLoadOp, BufferizeTensorToMemrefOp>(
       typeConverter, context);
 }
+
+namespace {
+struct FinalizingBufferizePass
+    : public FinalizingBufferizeBase<FinalizingBufferizePass> {
+  using FinalizingBufferizeBase<
+      FinalizingBufferizePass>::FinalizingBufferizeBase;
+
+  void runOnFunction() override {
+    auto func = getFunction();
+    auto *context = &getContext();
+
+    BufferizeTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+
+    populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
+                                                       patterns);
+    target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
+
+    // Any other operation is legal as long as none of its results are tensors.
+    // Note: block argument types are not checked or converted by this pass.
+    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
+      return typeConverter.isLegal(op->getResultTypes());
+    });
+
+    if (failed(applyFullConversion(func, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<FunctionPass> mlir::createFinalizingBufferizePass() {
+  return std::make_unique<FinalizingBufferizePass>();
+}

diff  --git a/mlir/test/Dialect/Standard/func-bufferize-partial.mlir b/mlir/test/Dialect/Standard/func-bufferize-partial.mlir
new file mode 100644
index 000000000000..2afa5327e572
--- /dev/null
+++ b/mlir/test/Dialect/Standard/func-bufferize-partial.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL:   func @block_arguments(
+// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK:           %[[T1:.*]] = tensor_load %[[ARG]] : memref<f32>
+// CHECK:           %[[M1:.*]] = tensor_to_memref %[[T1]] : memref<f32>
+// CHECK:           br ^bb1(%[[M1]] : memref<f32>)
+// CHECK:         ^bb1(%[[BBARG:.*]]: memref<f32>):
+// CHECK:           %[[T2:.*]] = tensor_load %[[BBARG]] : memref<f32>
+// CHECK:           %[[M2:.*]] = tensor_to_memref %[[T2]] : memref<f32>
+// CHECK:           return %[[M2]] : memref<f32>
+func @block_arguments(%arg0: tensor<f32>) -> tensor<f32> {
+  br ^bb1(%arg0: tensor<f32>)
+^bb1(%bbarg: tensor<f32>):
+  return %bbarg : tensor<f32>
+}
+
+// CHECK-LABEL: func @partial()
+// CHECK-SAME: memref<f32>
+func @partial() -> tensor<f32> {
+  // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor<f32>
+  // CHECK-NEXT: %[[MEM:.*]] = tensor_to_memref %[[SRC]] : memref<f32>
+  %0 = "test.source"() : () -> tensor<f32>
+  // CHECK-NEXT: return %[[MEM]] : memref<f32>
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func @region_op
+// CHECK-SAME: (%[[ARG0:.*]]: i1) -> memref<f32>
+func @region_op(%arg0: i1) -> tensor<f32> {
+  // CHECK-NEXT: %[[IF:.*]] = scf.if %[[ARG0]] -> (tensor<f32>)
+  %0 = scf.if %arg0 -> (tensor<f32>) {
+    // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor<f32>
+    %1 = "test.source"() : () -> tensor<f32>
+    // CHECK-NEXT: scf.yield %[[SRC]] : tensor<f32>
+    scf.yield %1 : tensor<f32>
+  // CHECK-NEXT: else
+  } else {
+    // CHECK-NEXT: %[[OSRC:.*]] = "test.other_source"() : () -> tensor<f32>
+    %1 = "test.other_source"() : () -> tensor<f32>
+    // CHECK-NEXT: scf.yield %[[OSRC]] : tensor<f32>
+    scf.yield %1 : tensor<f32>
+  }
+  // CHECK: %[[MEM:.*]] = tensor_to_memref %[[IF]] : memref<f32>
+  // CHECK: return %[[MEM]] : memref<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func @failed_to_legalize(%arg0: tensor<f32>) -> tensor<f32> {
+    %0 = constant true
+    cond_br %0, ^bb1(%arg0: tensor<f32>), ^bb2(%arg0: tensor<f32>)
+  ^bb1(%bbarg0: tensor<f32>):
+    // expected-error @+1 {{failed to legalize operation 'test.terminator'}}
+    "test.terminator"() : () -> ()
+  ^bb2(%bbarg1: tensor<f32>):
+    return %bbarg1 : tensor<f32>
+}

diff  --git a/mlir/test/Dialect/Standard/func-bufferize.mlir b/mlir/test/Dialect/Standard/func-bufferize.mlir
index 61c5e184cd17..d02db99aecd8 100644
--- a/mlir/test/Dialect/Standard/func-bufferize.mlir
+++ b/mlir/test/Dialect/Standard/func-bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -func-bufferize -finalizing-bufferize -split-input-file -verify-diagnostics | FileCheck %s
 
 // CHECK-LABEL:   func @identity(
 // CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {


        


More information about the Mlir-commits mailing list