[Mlir-commits] [mlir] 16cbe88 - [mlir][linalg][bufferize] Migrate --linalg-bufferize to BufferizableOpInterface-based bufferization
Matthias Springer
llvmlistbot at llvm.org
Thu Mar 3 03:13:05 PST 2022
Author: Matthias Springer
Date: 2022-03-03T20:12:37+09:00
New Revision: 16cbe883b57ceda7880b65bbeab83bff2493820a
URL: https://github.com/llvm/llvm-project/commit/16cbe883b57ceda7880b65bbeab83bff2493820a
DIFF: https://github.com/llvm/llvm-project/commit/16cbe883b57ceda7880b65bbeab83bff2493820a.diff
LOG: [mlir][linalg][bufferize] Migrate --linalg-bufferize to BufferizableOpInterface-based bufferization
This commit deletes the old dialect conversion-based bufferization patterns, which are now obsolete.
Differential Revision: https://reviews.llvm.org/D120883
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
mlir/test/Dialect/Linalg/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index ae03c4d36a192..593057d039b3c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -404,6 +404,17 @@ class AlwaysCopyBufferizationState : public BufferizationState {
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
ValueRange values);
+/// Lookup the buffer for the given value. If the value was not bufferized yet,
+/// wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp, from
+/// which the memref operand is returned.
+///
+/// Note: Use `BufferizationState::getBuffer` during bufferization.
+/// `lookupBuffer` is just for compatibility and gradual migration of
+/// bufferization patterns to BufferizableOpInterface-based bufferization. It
+/// does not insert any buffer copies.
+Value lookupBuffer(RewriterBase &rewriter, Value tensor,
+ const BufferizationOptions &options);
+
/// Replace an op with a new op. The new op must have the same number of
/// results as the replaced op. The new op may not return any tensor values.
template <typename OpTy, typename... Args>
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 54c1fa9968ac4..87f5cc17b4221 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -113,17 +113,6 @@ void populateFusePadTensorWithProducerLinalgOpPatterns(
/// canonicalizations of named ops into another named op.
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
-/// Populate the given list with patterns to bufferize linalg ops.
-void populateLinalgBufferizePatterns(
- bufferization::BufferizeTypeConverter &converter,
- RewritePatternSet &patterns);
-
-/// Create linalg op on buffers given the original tensor-based operation and
-/// the buffers for the outputs.
-LinalgOp createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
- LinalgOp linalgOp, ValueRange inputs,
- ValueRange outputs);
-
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
/// tensors.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index ee5a34bcc83b7..e5d94872ae586 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -212,8 +212,8 @@ static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#endif
}
-static Value lookupBuffer(RewriterBase &rewriter, Value tensor,
- const BufferizationOptions &options) {
+Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
+ const BufferizationOptions &options) {
auto tensorType = tensor.getType().dyn_cast<TensorType>();
assert(tensorType && "unexpected non-tensor type");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index 510e14cb63917..84b16bcd7a3f7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -1,4 +1,4 @@
-//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
+//===- Bufferize.cpp - Bufferization of linalg ops ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -8,208 +8,40 @@
#include "PassDetail.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
-using namespace ::mlir;
-using namespace ::mlir::linalg;
-
-static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
- auto memrefType = memref.getType().cast<MemRefType>();
- auto alloc = b.create<memref::AllocOp>(loc, memrefType,
- getDynOperands(loc, memref, b));
- b.create<memref::CopyOp>(loc, memref, alloc);
- return alloc;
-}
-
-static LogicalResult
-allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
- SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
- // Lazily compute loopRanges.
- SmallVector<Range, 4> loopRanges;
-
- // Allocate a buffer for every tensor result.
- assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
- for (const auto &en : llvm::enumerate(linalgOp->getResultTypes())) {
- size_t resultIndex = en.index();
- Type resultType = en.value();
-
- auto tensorType = resultType.dyn_cast<RankedTensorType>();
- if (tensorType == nullptr) {
- linalgOp.emitOpError()
- << "tensor to buffer conversion expects ranked tensor results";
- return failure();
- }
- auto tensorShape = tensorType.getShape();
- auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
- Value resultTensor = outputs[resultIndex];
-
- // Clone output buffers whose value is actually used.
- OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
- if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
- resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
- continue;
- }
-
- // Allocate buffers for statically-shaped results.
- if (memrefType.hasStaticShape()) {
- resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
- continue;
- }
-
- resultBuffers.push_back(b.create<memref::AllocOp>(
- loc, memrefType, getDynOperands(loc, resultTensor, b)));
- }
- return success();
-}
-
-/// Create linalg op on buffers given the original tensor-based operation and
-/// the buffers for the outputs.
-LinalgOp
-mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
- LinalgOp linalgOp, ValueRange inputs,
- ValueRange outputs) {
- SmallVector<Value, 8> newOperands = inputs;
- newOperands.append(outputs.begin(), outputs.end());
- auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(),
- /*resultTypes=*/ArrayRef<Type>{},
- newOperands);
- for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) {
- auto &oldRegion = std::get<0>(regions);
- auto &newRegion = std::get<1>(regions);
- rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
- }
- return newOp;
-}
-
-//===----------------------------------------------------------------------===//
-// Bufferization patterns.
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-/// Conversion pattern that replaces `linalg.init_tensor` with allocation.
-class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
-public:
- using OpConversionPattern<InitTensorOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(InitTensorOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- rewriter.replaceOpWithNewOp<memref::AllocOp>(
- op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
- adaptor.sizes());
- return success();
- }
-};
-
-/// Conversion pattern that bufferizes `linalg.fill` operation.
-class BufferizeFillOp : public OpConversionPattern<FillOp> {
-public:
- using OpConversionPattern<FillOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(FillOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- if (!op.output().getType().isa<TensorType>())
- return rewriter.notifyMatchFailure(op,
- "operand must be of a tensor type");
-
- rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
- rewriter.replaceOp(op, adaptor.output());
-
- return success();
- }
-};
-
-/// Generic conversion pattern that matches any LinalgOp. This avoids template
-/// instantiating one pattern for each LinalgOp.
-class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
-public:
- using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;
-
- LogicalResult
- matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
- if (!op->hasAttr("operand_segment_sizes"))
- return failure();
-
- // We abuse the GenericOpAdaptor here.
- // TODO: Manually create an Adaptor that captures inputs and outputs for all
- // linalg::LinalgOp interface ops.
- linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());
-
- Location loc = op.getLoc();
- SmallVector<Value, 2> newOutputBuffers;
-
- if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
- newOutputBuffers, rewriter))) {
- return op.emitOpError()
- << "Failed to allocate buffers for tensor results.";
- }
- createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
- // Replace the results of the old op with the new output buffers.
- rewriter.replaceOp(op, newOutputBuffers);
- return success();
- }
-};
-} // namespace
+using namespace mlir;
+using namespace bufferization;
namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
void runOnOperation() override {
- MLIRContext &context = getContext();
- ConversionTarget target(context);
- bufferization::BufferizeTypeConverter typeConverter;
-
- // Mark certain operations legal.
- target.addLegalDialect<arith::ArithmeticDialect, AffineDialect,
- memref::MemRefDialect, tensor::TensorDialect>();
- target.addIllegalOp<InitTensorOp>();
+ BufferizationOptions options = getPartialBufferizationOptions();
+ options.allowDialectInFilter<linalg::LinalgDialect>();
- // Mark all Linalg operations illegal as long as they work on tensors.
- auto isLegalOperation = [&](Operation *op) {
- return typeConverter.isLegal(op);
- };
- target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
-
- RewritePatternSet patterns(&context);
- populateLinalgBufferizePatterns(typeConverter, patterns);
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
+ if (failed(bufferizeOp(getOperation(), options)))
signalPassFailure();
}
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+ tensor::TensorDialect, linalg::LinalgDialect>();
+ linalg::registerBufferizableOpInterfaceExternalModels(registry);
+ }
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
return std::make_unique<LinalgBufferizePass>();
}
-
-void mlir::linalg::populateLinalgBufferizePatterns(
- bufferization::BufferizeTypeConverter &typeConverter,
- RewritePatternSet &patterns) {
- // TODO: Drop this once tensor constants work in standard.
- // clang-format off
- patterns.add<
- BufferizeAnyLinalgOp,
- BufferizeFillOp,
- BufferizeInitTensorOp
- >(typeConverter, patterns.getContext());
- // clang-format on
-}
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 2edc104ccad36..e6d4f92b0a0f5 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -linalg-bufferize -canonicalize -cse -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -linalg-bufferize -canonicalize -cse -split-input-file %s | FileCheck %s
#map0 = affine_map<(d0) -> (d0)>
@@ -12,8 +12,8 @@
// CHECK: #map = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @basic(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<4xf32>) -> tensor<4xf32> {
-// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<4xf32>
-// CHECK: %[[RESULT_MEMREF:.*]] = memref.alloc() : memref<4xf32>
+// CHECK-DAG: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<4xf32>
+// CHECK-DAG: %[[RESULT_MEMREF:.*]] = memref.alloc() {{.*}} : memref<4xf32>
// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
// CHECK-SAME: ins(%[[MEMREF]] : memref<4xf32>)
// CHECK-SAME: outs(%[[RESULT_MEMREF]] : memref<4xf32>) {
@@ -46,8 +46,8 @@ func @basic(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: #map = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @init_tensor(
// CHECK-SAME: %[[IN:.*]]: tensor<?xf32>, %[[SIZE:.*]]: index)
-// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[IN]] : memref<?xf32>
-// CHECK: %[[OUT_BUF:.*]] = memref.alloc(%[[SIZE]]) : memref<?xf32>
+// CHECK-DAG: %[[MEMREF:.*]] = bufferization.to_memref %[[IN]] : memref<?xf32>
+// CHECK-DAG: %[[OUT_BUF:.*]] = memref.alloc(%[[SIZE]]) {{.*}} : memref<?xf32>
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[MEMREF]] : memref<?xf32>)
// CHECK-SAME: outs(%[[OUT_BUF]] : memref<?xf32>) {
@@ -71,8 +71,8 @@ func @init_tensor(%in : tensor<?xf32>, %size: index) -> tensor<?xf32> {
#map0 = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @multiple_results
-// CHECK: %[[RESULT0:.*]] = memref.alloc() : memref<4xf32>
-// CHECK: %[[RESULT1:.*]] = memref.alloc() : memref<4xf32>
+// CHECK: %[[RESULT1:.*]] = memref.alloc() {{.*}} : memref<4xf32>
+// CHECK: %[[RESULT0:.*]] = memref.alloc() {{.*}} : memref<4xf32>
// CHECK: linalg.generic
// CHECK-SAME: ins(%{{.*}} : memref<4xf32>)
// CHECK-SAME: outs(%[[RESULT0]], %[[RESULT1]] : memref<4xf32>, memref<4xf32>)
@@ -101,11 +101,11 @@ func @multiple_results(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
// CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[MEMREF_ARG:.*]] = bufferization.to_memref %[[ARG]] : memref<?x?xf32>
// CHECK: %[[DIM0:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[RESULT0:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32>
-// CHECK: %[[RESULT1:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32>
+// CHECK: %[[RESULT1:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {{.*}} : memref<?x?xf32>
+// CHECK: %[[RESULT0:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {{.*}} : memref<?x?xf32>
+// CHECK: %[[MEMREF_ARG:.*]] = bufferization.to_memref %[[ARG]] : memref<?x?xf32>
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[MEMREF_ARG]] : memref<?x?xf32>)
// CHECK-SAME: outs(%[[RESULT0]], %[[RESULT1]] : memref<?x?xf32>, memref<?x?xf32>)
@@ -140,9 +140,9 @@ func @dynamic_results(%arg0: tensor<?x?xf32>)
// CHECK-LABEL: func @generic_with_init_tensor(
// CHECK-SAME: %[[ARG0_TENSOR:.*]]: tensor<2x3x4xvector<3x4xi4>>,
// CHECK-SAME: %[[ARG1_TENSOR:.*]]: tensor<3x2xf32>) -> tensor<3x2xf32> {
+// CHECK: %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<3x2xf32>
// CHECK-DAG: %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0_TENSOR]] : memref<2x3x4xvector<3x4xi4>>
// CHECK-DAG: %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1_TENSOR]] : memref<3x2xf32>
-// CHECK: %[[INIT_BUFFER:.*]] = memref.alloc() : memref<3x2xf32>
// CHECK: memref.copy %[[ARG1_MEMREF]], %[[INIT_BUFFER]] : memref<3x2xf32> to memref<3x2xf32>
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[ARG0_MEMREF]] : memref<2x3x4xvector<3x4xi4>>)
@@ -166,9 +166,9 @@ func @generic_with_init_tensor(%arg0: tensor<2x3x4xvector<3x4xi4>>,
// CHECK-SAME: %[[IN:.*]]: tensor<?xf32>
func @bufferize_fill(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%c0 = arith.constant 0.0 : f32
- // CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[IN]] : memref<?xf32>
- // CHECK: linalg.fill(%cst, %[[MEMREF]]) : f32, memref<?xf32>
- // CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[MEMREF]] : memref<?xf32>
+ // CHECK: %[[ALLOC:.*]] = memref.alloc
+ // CHECK: linalg.fill(%cst, %[[ALLOC]]) : f32, memref<?xf32>
+ // CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<?xf32>
// CHECK: return %[[TENSOR]]
%0 = linalg.fill(%c0, %arg0) : f32, tensor<?xf32> -> tensor<?xf32>
return %0 : tensor<?xf32>
@@ -179,10 +179,13 @@ func @bufferize_fill(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-LABEL: func @bufferize_dot
func @bufferize_dot(%in: tensor<4xf32>, %out: tensor<f32>) -> tensor<f32> {
%dot = linalg.dot ins(%in, %in : tensor<4xf32>, tensor<4xf32>)
- outs(%out : tensor<f32>) -> tensor<f32>
+ outs(%out : tensor<f32>) -> tensor<f32>
return %dot : tensor<f32>
+ // CHECK: %[[ALLOC:.*]] = memref.alloc
+ // TODO: The copy is not necessary.
+ // CHECK: memref.copy {{.*}}, %[[ALLOC]]
// CHECK: linalg.dot ins(%{{.*}}, %{{.*}} : memref<4xf32>, memref<4xf32>)
- // CHECK-SAME: outs(%[[OUT:.*]] : memref<f32>)
- // CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[OUT]] : memref<f32>
+ // CHECK-SAME: outs(%[[ALLOC:.*]] : memref<f32>)
+ // CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<f32>
// CHECK: return %[[OUT_TENSOR]]
}
More information about the Mlir-commits
mailing list