[Mlir-commits] [mlir] 4dd5f79 - [mlir][bufferize] Add argument materialization for bufferization
Stephan Herhut
llvmlistbot at llvm.org
Thu Nov 26 04:44:34 PST 2020
Author: Stephan Herhut
Date: 2020-11-26T13:43:44+01:00
New Revision: 4dd5f79f07022dbbff547f4aff13b27134331215
URL: https://github.com/llvm/llvm-project/commit/4dd5f79f07022dbbff547f4aff13b27134331215
DIFF: https://github.com/llvm/llvm-project/commit/4dd5f79f07022dbbff547f4aff13b27134331215.diff
LOG: [mlir][bufferize] Add argument materialization for bufferization
This enables partial bufferization that includes function signatures. To test this, this
change also makes the func-bufferize pass partial and adds a dedicated finalizing-bufferize pass.
Differential Revision: https://reviews.llvm.org/D92032
Added:
mlir/test/Dialect/Standard/func-bufferize-partial.mlir
Modified:
mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
mlir/include/mlir/Transforms/Passes.h
mlir/include/mlir/Transforms/Passes.td
mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
mlir/lib/Transforms/Bufferize.cpp
mlir/test/Dialect/Standard/func-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
index 5a1bc7b9716e..55da3af88758 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
@@ -26,6 +26,13 @@ void populateCallOpTypeConversionPattern(OwningRewritePatternList &patterns,
MLIRContext *ctx,
TypeConverter &converter);
+/// Add patterns to the given pattern list to rewrite branch operations that
+/// implement the BranchOpInterface, as well as `return` operations, to use
+/// operands that have been legalized by the conversion framework. Branch
+/// operations that do not implement the BranchOpInterface are not rewritten.
+/// Only needed for partial conversions.
+void populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
+ OwningRewritePatternList &patterns, MLIRContext *ctx,
+ TypeConverter &converter);
} // end namespace mlir
#endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
index 3be398fecb0c..9623dd14a296 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -25,28 +25,26 @@ def StdExpandOps : FunctionPass<"std-expand"> {
def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
let summary = "Bufferize func/call/return ops";
let description = [{
- A finalizing bufferize pass that bufferizes std.func and std.call ops.
+ A bufferize pass that bufferizes std.func and std.call ops.
Because this pass updates std.func ops, it must be a module pass. It is
useful to keep this pass separate from other bufferizations so that the
other ones can be run at function-level in parallel.
- This pass must be done atomically for two reasons:
- 1. This pass changes func op signatures, which requires atomically updating
- calls as well throughout the entire module.
- 2. This pass changes the type of block arguments, which requires that all
- successor arguments of predecessors be converted. Terminators are not
- a closed universe (and need not implement BranchOpInterface), and so we
- cannot in general rewrite them.
+ This pass must be done atomically because it changes func op signatures,
+ which requires atomically updating calls as well throughout the entire
+ module.
- Note, because this is a "finalizing" bufferize step, it can create
- invalid IR because it will not create materializations. To avoid this
- situation, the pass must only be run when the only SSA values of
- tensor type are:
- - block arguments
- - the result of tensor_load
- Other values of tensor type should be eliminated by earlier
- bufferization passes.
+ This pass also changes the type of block arguments, which requires that all
+ successor arguments of predecessors be converted. This is achieved by
+ rewriting terminators based on the information provided by the
+ `BranchOpInterface`.
+ As this pass rewrites function operations, it also rewrites the
+ corresponding return operations. Other return-like operations that
+ implement the `ReturnLike` trait are not rewritten in general, as they
+ require that the corresponding parent operation is also rewritten.
+ Finally, this pass fails for unknown terminators, as we cannot decide
+ whether they need rewriting.
}];
let constructor = "mlir::createFuncBufferizePass()";
}
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 2e3437a46611..77d98ce79cca 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -46,6 +46,10 @@ std::unique_ptr<Pass>
createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes = 1024,
unsigned bitwidthOfIndexType = 64);
+/// Creates a pass that finalizes a partial bufferization by removing remaining
+/// tensor_load and tensor_to_memref operations. The pass signals failure if
+/// the full conversion that removes them does not succeed.
+std::unique_ptr<FunctionPass> createFinalizingBufferizePass();
+
/// Creates a pass that converts memref function results to out-params.
std::unique_ptr<Pass> createBufferResultsToOutParamsPass();
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index da4ca24db499..29fe43fc0169 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -290,6 +290,22 @@ def Inliner : Pass<"inline"> {
];
}
+// Function pass; the C++ implementation is FinalizingBufferizePass in
+// mlir/lib/Transforms/Bufferize.cpp.
+def FinalizingBufferize : FunctionPass<"finalizing-bufferize"> {
+ let summary = "Finalize a partial bufferization";
+ let description = [{
+ A bufferize pass that finalizes a partial bufferization by removing
+ remaining `tensor_load` and `tensor_to_memref` operations.
+
+ The removal of those operations is only possible if the operations only
+ exist in pairs, i.e., all uses of `tensor_load` operations are
+ `tensor_to_memref` operations.
+
+ This pass will fail if not all operations can be removed or if any operation
+ with tensor typed operands remains.
+ }];
+ let constructor = "mlir::createFinalizingBufferizePass()";
+}
+
+
def LocationSnapshot : Pass<"snapshot-op-locations"> {
let summary = "Generate new locations from the current IR";
let description = [{
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
index 4aadb72e6368..1aace4517f71 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
@@ -21,6 +21,8 @@ using namespace mlir;
namespace {
struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
+ using FuncBufferizeBase<FuncBufferizePass>::FuncBufferizeBase;
+
void runOnOperation() override {
auto module = getOperation();
auto *context = &getContext();
@@ -35,14 +37,42 @@ struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
typeConverter.isLegal(&op.getBody());
});
populateCallOpTypeConversionPattern(patterns, context, typeConverter);
- populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
- patterns);
- target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
+ target.addDynamicallyLegalOp<CallOp>(
+ [&](CallOp op) { return typeConverter.isLegal(op); });
- // If all result types are legal, and all block arguments are legal (ensured
- // by func conversion above), then all types in the program are legal.
+ populateBranchOpInterfaceAndReturnOpTypeConversionPattern(patterns, context,
+ typeConverter);
+ target.addLegalOp<ModuleOp, ModuleTerminatorOp, TensorLoadOp,
+ TensorToMemrefOp>();
+ target.addDynamicallyLegalOp<ReturnOp>(
+ [&](ReturnOp op) { return typeConverter.isLegal(op); });
+ // Mark terminators as legal if they have the ReturnLike trait or
+ // implement the BranchOpInterface and have valid types. If they do not
+ // implement the trait or interface, mark them as illegal no matter what.
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
- return typeConverter.isLegal(op->getResultTypes());
+ // If it is not a terminator, ignore it.
+ if (op->isKnownNonTerminator())
+ return true;
+ // If it is not the last operation in the block, also ignore it. We do
+ // this to handle unknown operations, as well.
+ Block *block = op->getBlock();
+ if (!block || &block->back() != op)
+ return true;
+ // ReturnLike operations have to be legalized with their parent. For
+ // return this is handled, for other ops they remain as is.
+ if (op->hasTrait<OpTrait::ReturnLike>())
+ return true;
+ // All successor operands of branch like operations must be rewritten.
+ // Successors are queried on the op itself rather than via its block.
+ if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
+ for (unsigned p = 0, e = op->getNumSuccessors(); p < e; ++p) {
+ auto successorOperands = branchOp.getSuccessorOperands(p);
+ if (successorOperands.hasValue() &&
+ !typeConverter.isLegal(successorOperands.getValue().getTypes()))
+ return false;
+ }
+ return true;
+ }
+ // Terminators that are neither ReturnLike nor branch-like cannot be
+ // rewritten safely, so they are always illegal.
+ return false;
});
if (failed(applyFullConversion(module, target, std::move(patterns))))
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
index 9d8fceb16db3..07d7c59e192b 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
@@ -13,21 +13,19 @@
using namespace mlir;
namespace {
-// Converts the operand and result types of the Standard's CallOp, used together
-// with the FuncOpSignatureConversion.
+/// Converts the operand and result types of the Standard's CallOp, used
+/// together with the FuncOpSignatureConversion. Result types are converted
+/// with the type converter held by the OpConversionPattern base; operands are
+/// taken from the already-remapped `operands` passed to matchAndRewrite.
struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
- CallOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
- : OpConversionPattern(ctx), converter(converter) {}
+ using OpConversionPattern<CallOp>::OpConversionPattern;
/// Hook for derived classes to implement combined matching and rewriting.
LogicalResult
matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- FunctionType type = callOp.getCalleeType();
-
// Convert the original function results.
SmallVector<Type, 1> convertedResults;
- if (failed(converter.convertTypes(type.getResults(), convertedResults)))
+ // The pattern fails to match if any call result type cannot be converted.
+ if (failed(typeConverter->convertTypes(callOp.getResultTypes(),
+ convertedResults)))
return failure();
// Substitute with the new result types from the corresponding FuncType
@@ -36,14 +34,77 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
convertedResults, operands);
return success();
}
-
- /// The type converter to use when rewriting the signature.
- TypeConverter &converter;
};
} // end anonymous namespace
void mlir::populateCallOpTypeConversionPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx,
TypeConverter &converter) {
- patterns.insert<CallOpSignatureConversion>(ctx, converter);
+ // CallOpSignatureConversion inherits the OpConversionPattern constructor,
+ // which takes the type converter first and the context second.
+ patterns.insert<CallOpSignatureConversion>(converter, ctx);
+}
+
+namespace {
+/// Only needed to support partial conversion of functions where this pattern
+/// ensures that the branch operation arguments match up with the successor
+/// block arguments. Matches any operation and bails out unless it implements
+/// the BranchOpInterface.
+class BranchOpInterfaceTypeConversion : public ConversionPattern {
+public:
+ // The context parameter is kept so the pattern can be inserted with the
+ // same (typeConverter, ctx) arguments as OpConversionPatterns.
+ BranchOpInterfaceTypeConversion(TypeConverter &typeConverter,
+ MLIRContext * /*ctx*/)
+ : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto branchOp = dyn_cast<BranchOpInterface>(op);
+ if (!branchOp)
+ return failure();
+
+ // For a branch operation, only some operands go to the target blocks, so
+ // only rewrite those; every other operand keeps its original value.
+ SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
+ for (unsigned succIdx = 0, succEnd = op->getNumSuccessors();
+ succIdx < succEnd; ++succIdx) {
+ auto successorOperands = branchOp.getSuccessorOperands(succIdx);
+ if (!successorOperands)
+ continue;
+ unsigned begin = successorOperands->getBeginOperandIndex();
+ for (unsigned idx = begin, eidx = begin + successorOperands->size();
+ idx < eidx; ++idx)
+ newOperands[idx] = operands[idx];
+ }
+ // The callback runs immediately inside updateRootInPlace, so capturing by
+ // reference avoids copying the operand vector.
+ rewriter.updateRootInPlace(op, [&] { op->setOperands(newOperands); });
+ return success();
+ }
+};
+} // end anonymous namespace
+
+namespace {
+/// Only needed to support partial conversion of functions where this pattern
+/// replaces the operands of a `return` with their converted values, so that
+/// they match the converted result types of the enclosing function.
+class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
+public:
+ using OpConversionPattern<ReturnOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // For a return, all operands go to the results of the parent, so
+ // rewrite them all.
+ Operation *operation = op.getOperation();
+ rewriter.updateRootInPlace(
+ op, [&] { operation->setOperands(operands); });
+ return success();
+ }
+};
+} // end anonymous namespace
+
+void mlir::populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
+ OwningRewritePatternList &patterns, MLIRContext *ctx,
+ TypeConverter &typeConverter) {
+ // Branches: only the operands forwarded to successor blocks are rewritten.
+ patterns.insert<BranchOpInterfaceTypeConversion>(typeConverter, ctx);
+ // Returns: every operand is rewritten.
+ patterns.insert<ReturnOpTypeConversion>(typeConverter, ctx);
}
diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp
index ba622335a396..1811ac8bdfbc 100644
--- a/mlir/lib/Transforms/Bufferize.cpp
+++ b/mlir/lib/Transforms/Bufferize.cpp
@@ -7,7 +7,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Bufferize.h"
+#include "PassDetail.h"
#include "mlir/IR/Operation.h"
+#include "mlir/Transforms/Passes.h"
using namespace mlir;
@@ -15,6 +17,13 @@ using namespace mlir;
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//
+/// Builds a value of tensor type `type` from the single memref in `inputs` by
+/// creating a tensor_load at `loc`.
+static Value materializeTensorLoad(OpBuilder &builder, TensorType type,
+ ValueRange inputs, Location loc) {
+ assert(inputs.size() == 1);
+ Value memref = inputs.front();
+ assert(memref.getType().isa<BaseMemRefType>());
+ return builder.create<TensorLoadOp>(loc, type, memref);
+}
+
/// Registers conversions into BufferizeTypeConverter
BufferizeTypeConverter::BufferizeTypeConverter() {
// Keep all types unchanged.
@@ -27,12 +36,8 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
addConversion([](UnrankedTensorType type) -> Type {
return UnrankedMemRefType::get(type.getElementType(), 0);
});
- addSourceMaterialization([](OpBuilder &builder, TensorType type,
- ValueRange inputs, Location loc) -> Value {
- assert(inputs.size() == 1);
- assert(inputs[0].getType().isa<BaseMemRefType>());
- return builder.create<TensorLoadOp>(loc, type, inputs[0]);
- });
+ // Both the argument and the source materialization rebuild a tensor from a
+ // single memref input by inserting a tensor_load (materializeTensorLoad).
+ addArgumentMaterialization(materializeTensorLoad);
+ addSourceMaterialization(materializeTensorLoad);
addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
ValueRange inputs, Location loc) -> Value {
assert(inputs.size() == 1);
@@ -83,3 +88,37 @@ void mlir::populateEliminateBufferizeMaterializationsPatterns(
patterns.insert<BufferizeTensorLoadOp, BufferizeTensorToMemrefOp>(
typeConverter, context);
}
+
+namespace {
+/// Finalizes a partial bufferization of a function by eliminating the
+/// remaining tensor_load / tensor_to_memref materializations.
+struct FinalizingBufferizePass
+ : public FinalizingBufferizeBase<FinalizingBufferizePass> {
+ using FinalizingBufferizeBase<
+ FinalizingBufferizePass>::FinalizingBufferizeBase;
+
+ void runOnFunction() override {
+ auto func = getFunction();
+ auto *context = &getContext();
+
+ BufferizeTypeConverter typeConverter;
+ OwningRewritePatternList patterns;
+ ConversionTarget target(*context);
+
+ populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
+ patterns);
+ target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
+
+ // An operation is legal only if both its operand and result types are
+ // legal. Checking results alone is not enough: an op that merely consumes
+ // a tensor (e.g. a return) would be considered legal and could end up
+ // using the memref that replaced an eliminated tensor_load while its
+ // enclosing function is left unchanged, producing invalid IR. Checking
+ // operands makes the pass fail instead, as its description promises.
+ target.markUnknownOpDynamicallyLegal(
+ [&](Operation *op) { return typeConverter.isLegal(op); });
+
+ if (failed(applyFullConversion(func, target, std::move(patterns))))
+ signalPassFailure();
+ }
+};
+} // namespace
+
+std::unique_ptr<FunctionPass> mlir::createFinalizingBufferizePass() {
+ return std::make_unique<FinalizingBufferizePass>();
+}
diff --git a/mlir/test/Dialect/Standard/func-bufferize-partial.mlir b/mlir/test/Dialect/Standard/func-bufferize-partial.mlir
new file mode 100644
index 000000000000..2afa5327e572
--- /dev/null
+++ b/mlir/test/Dialect/Standard/func-bufferize-partial.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func @block_arguments(
+// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK: %[[T1:.*]] = tensor_load %[[ARG]] : memref<f32>
+// CHECK: %[[M1:.*]] = tensor_to_memref %[[T1]] : memref<f32>
+// CHECK: br ^bb1(%[[M1]] : memref<f32>)
+// CHECK: ^bb1(%[[BBARG:.*]]: memref<f32>):
+// CHECK: %[[T2:.*]] = tensor_load %[[BBARG]] : memref<f32>
+// CHECK: %[[M2:.*]] = tensor_to_memref %[[T2]] : memref<f32>
+// CHECK: return %[[M2]] : memref<f32>
+func @block_arguments(%arg0: tensor<f32>) -> tensor<f32> {
+ br ^bb1(%arg0: tensor<f32>)
+^bb1(%bbarg: tensor<f32>):
+ return %bbarg : tensor<f32>
+}
+
+// CHECK-LABEL: func @partial()
+// CHECK-SAME: memref<f32>
+func @partial() -> tensor<f32> {
+ // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor<f32>
+ // CHECK-NEXT: %[[MEM:.*]] = tensor_to_memref %[[SRC]] : memref<f32>
+ %0 = "test.source"() : () -> tensor<f32>
+ // CHECK-NEXT: return %[[MEM]] : memref<f32>
+ return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func @region_op
+// CHECK-SAME: (%[[ARG0:.*]]: i1) -> memref<f32>
+func @region_op(%arg0: i1) -> tensor<f32> {
+ // CHECK-NEXT: %[[IF:.*]] = scf.if %[[ARG0]] -> (tensor<f32>)
+ %0 = scf.if %arg0 -> (tensor<f32>) {
+ // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor<f32>
+ %1 = "test.source"() : () -> tensor<f32>
+ // CHECK-NEXT: scf.yield %[[SRC]] : tensor<f32>
+ scf.yield %1 : tensor<f32>
+ // CHECK-NEXT: else
+ } else {
+ // CHECK-NEXT: %[[OSRC:.*]] = "test.other_source"() : () -> tensor<f32>
+ %1 = "test.other_source"() : () -> tensor<f32>
+ // CHECK-NEXT: scf.yield %[[OSRC]] : tensor<f32>
+ scf.yield %1 : tensor<f32>
+ }
+ // CHECK: %[[MEM:.*]] = tensor_to_memref %[[IF]] : memref<f32>
+ // CHECK: return %[[MEM]] : memref<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+func @failed_to_legalize(%arg0: tensor<f32>) -> tensor<f32> {
+ %0 = constant true
+ cond_br %0, ^bb1(%arg0: tensor<f32>), ^bb2(%arg0: tensor<f32>)
+ ^bb1(%bbarg0: tensor<f32>):
+ // expected-error @+1 {{failed to legalize operation 'test.terminator'}}
+ "test.terminator"() : () -> ()
+ ^bb2(%bbarg1: tensor<f32>):
+ return %bbarg1 : tensor<f32>
+}
diff --git a/mlir/test/Dialect/Standard/func-bufferize.mlir b/mlir/test/Dialect/Standard/func-bufferize.mlir
index 61c5e184cd17..d02db99aecd8 100644
--- a/mlir/test/Dialect/Standard/func-bufferize.mlir
+++ b/mlir/test/Dialect/Standard/func-bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -func-bufferize -finalizing-bufferize -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func @identity(
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
More information about the Mlir-commits
mailing list